From f772d95291c01e40ac73f1364fec31255d25fac9 Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Wed, 25 Jul 2018 18:41:20 +0300 Subject: [PATCH] compress_classifier.py: code refactoring We are using this file for more and more use-cases and we need to keep it readable and clean. I've tried to move code that is not in the main control-path to specific functions. --- .../compress_classifier.py | 134 ++++++++++-------- 1 file changed, 77 insertions(+), 57 deletions(-) diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index 4a9234c..55f9b19 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -235,35 +235,11 @@ def main(): msglogger.info('Optimizer Args: %s', optimizer.defaults) if args.ADC: - import examples.automated_deep_compression.ADC as ADC - HAVE_COACH_INSTALLED = True - if not HAVE_COACH_INSTALLED: - raise ValueError("ADC is currently experimental and uses non-public Coach features") - - train_loader, val_loader, test_loader, _ = apputils.load_data( - args.dataset, os.path.expanduser(args.data), args.batch_size, - args.workers, args.validation_size, args.deterministic) - - args.display_confusion = True - validate_fn = partial(validate, val_loader=test_loader, criterion=criterion, - loggers=[pylogger], args=args) - - if args.ADC_params is not None: - ADC.summarize_experiment(args.ADC_params, args.dataset, args.arch, validate_fn) - exit() - - save_checkpoint_fn = partial(apputils.save_checkpoint, arch=args.arch, dir=msglogger.logdir) - ADC.do_adc(model, args.dataset, args.arch, val_loader, validate_fn, save_checkpoint_fn) - exit() + return automated_deep_compression(model, criterion, pylogger, args) # This sample application can be invoked to produce various summary reports. if args.summary: - which_summary = args.summary - if which_summary.startswith('png'): - apputils.draw_img_classifier_to_file(model, 'model.png', args.dataset, which_summary == 'png_w_params') - else: - distiller.model_summary(model, which_summary, args.dataset) - exit() + return summarize_model(model, args.dataset, which_summary=args.summary) # Load the datasets: the dataset to load is inferred from the model name passed # in args.arch. The default dataset is ImageNet, but if args.arch contains the @@ -282,39 +258,10 @@ def main(): activations_sparsity = ActivationSparsityCollector(model) if args.sensitivity is not None: - # This sample application can be invoked to execute Sensitivity Analysis on your - # model. The ouptut is saved to CSV and PNG. - msglogger.info("Running sensitivity tests") - test_fnc = partial(test, test_loader=test_loader, criterion=criterion, - loggers=[pylogger], args=args) - which_params = [param_name for param_name, _ in model.named_parameters()] - sensitivity = distiller.perform_sensitivity_analysis(model, - net_params=which_params, - sparsities=np.arange(0.0, 0.95, 0.05), - test_func=test_fnc, - group=args.sensitivity) - distiller.sensitivities_to_png(sensitivity, 'sensitivity.png') - distiller.sensitivities_to_csv(sensitivity, 'sensitivity.csv') - exit() + return sensitivity_analysis(model, criterion, test_loader, pylogger, args) if args.evaluate: - # This sample application can be invoked to evaluate the accuracy of your model on - # the test dataset. - # You can optionally quantize the model to 8-bit integer before evaluation. - # For example: - # python3 compress_classifier.py --arch resnet20_cifar ../data.cifar10 -p=50 --resume=checkpoint.pth.tar --evaluate - if args.quantize: - model.cpu() - quantizer = quantization.SymmetricLinearQuantizer(model, 8, 8) - quantizer.prepare_model() - model.cuda() - top1, _, _ = test(test_loader, model, criterion, [pylogger], args=args) - if args.quantize: - checkpoint_name = 'quantized' - apputils.save_checkpoint(0, args.arch, model, optimizer=None, best_top1=top1, - name='_'.split(args.name, checkpoint_name) if args.name else checkpoint_name, - dir=msglogger.logdir) - exit() + return evaluate_model(model, criterion, test_loader, pylogger, args) if args.compress: # The main use-case for this sample application is CNN compression. Compression @@ -532,6 +479,79 @@ def get_inference_var(tensor): return torch.autograd.Variable(tensor, volatile=True) +def evaluate_model(model, criterion, test_loader, loggers, args): + # This sample application can be invoked to evaluate the accuracy of your model on + # the test dataset. + # You can optionally quantize the model to 8-bit integer before evaluation. + # For example: + # python3 compress_classifier.py --arch resnet20_cifar ../data.cifar10 -p=50 --resume=checkpoint.pth.tar --evaluate + + if not isinstance(loggers, list): + loggers = [loggers] + + if args.quantize: + model.cpu() + quantizer = quantization.SymmetricLinearQuantizer(model, 8, 8) + quantizer.prepare_model() + model.cuda() + top1, _, _ = test(test_loader, model, criterion, loggers, args=args) + if args.quantize: + checkpoint_name = 'quantized' + apputils.save_checkpoint(0, args.arch, model, optimizer=None, best_top1=top1, + name='_'.split(args.name, checkpoint_name) if args.name else checkpoint_name, + dir=msglogger.logdir) + + +def summarize_model(model, dataset, which_summary): + if which_summary.startswith('png'): + apputils.draw_img_classifier_to_file(model, 'model.png', dataset, which_summary == 'png_w_params') + else: + distiller.model_summary(model, which_summary, dataset) + + +def sensitivity_analysis(model, criterion, data_loader, loggers, args): + # This sample application can be invoked to execute Sensitivity Analysis on your + # model. The ouptut is saved to CSV and PNG. + msglogger.info("Running sensitivity tests") + if not isinstance(loggers, list): + loggers = [loggers] + test_fnc = partial(test, test_loader=data_loader, criterion=criterion, + loggers=loggers, args=args) + which_params = [param_name for param_name, _ in model.named_parameters()] + sensitivity = distiller.perform_sensitivity_analysis(model, + net_params=which_params, + sparsities=np.arange(0.0, 0.95, 0.05), + test_func=test_fnc, + group=args.sensitivity) + distiller.sensitivities_to_png(sensitivity, 'sensitivity.png') + distiller.sensitivities_to_csv(sensitivity, 'sensitivity.csv') + + +def automated_deep_compression(model, criterion, loggers, args): + import examples.automated_deep_compression.ADC as ADC + HAVE_COACH_INSTALLED = True + if not HAVE_COACH_INSTALLED: + raise ValueError("ADC is currently experimental and uses non-public Coach features") + + if not isinstance(loggers, list): + loggers = [loggers] + + train_loader, val_loader, test_loader, _ = apputils.load_data( + args.dataset, os.path.expanduser(args.data), args.batch_size, + args.workers, args.validation_size, args.deterministic) + + args.display_confusion = True + validate_fn = partial(validate, val_loader=test_loader, criterion=criterion, + loggers=loggers, args=args) + + if args.ADC_params is not None: + ADC.summarize_experiment(args.ADC_params, args.dataset, args.arch, validate_fn) + exit() + + save_checkpoint_fn = partial(apputils.save_checkpoint, arch=args.arch, dir=msglogger.logdir) + ADC.do_adc(model, args.dataset, args.arch, val_loader, validate_fn, save_checkpoint_fn) + + if __name__ == '__main__': try: main() -- GitLab