From f772d95291c01e40ac73f1364fec31255d25fac9 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Wed, 25 Jul 2018 18:41:20 +0300
Subject: [PATCH] compress_classifier.py: code refactoring

We are using this file for more and more use-cases and we need to keep
it readable and clean.
I've tried to move code that is not in the main control path into
specific functions.
---
 .../compress_classifier.py                    | 134 ++++++++++--------
 1 file changed, 77 insertions(+), 57 deletions(-)

diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index 4a9234c..55f9b19 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -235,35 +235,11 @@ def main():
     msglogger.info('Optimizer Args: %s', optimizer.defaults)
 
     if args.ADC:
-        import examples.automated_deep_compression.ADC as ADC
-        HAVE_COACH_INSTALLED = True
-        if not HAVE_COACH_INSTALLED:
-            raise ValueError("ADC is currently experimental and uses non-public Coach features")
-
-        train_loader, val_loader, test_loader, _ = apputils.load_data(
-            args.dataset, os.path.expanduser(args.data), args.batch_size,
-            args.workers, args.validation_size, args.deterministic)
-
-        args.display_confusion = True
-        validate_fn = partial(validate, val_loader=test_loader, criterion=criterion,
-                              loggers=[pylogger], args=args)
-
-        if args.ADC_params is not None:
-            ADC.summarize_experiment(args.ADC_params, args.dataset, args.arch, validate_fn)
-            exit()
-
-        save_checkpoint_fn = partial(apputils.save_checkpoint, arch=args.arch, dir=msglogger.logdir)
-        ADC.do_adc(model, args.dataset, args.arch, val_loader, validate_fn, save_checkpoint_fn)
-        exit()
+        return automated_deep_compression(model, criterion, pylogger, args)
 
     # This sample application can be invoked to produce various summary reports.
     if args.summary:
-        which_summary = args.summary
-        if which_summary.startswith('png'):
-            apputils.draw_img_classifier_to_file(model, 'model.png', args.dataset, which_summary == 'png_w_params')
-        else:
-            distiller.model_summary(model, which_summary, args.dataset)
-        exit()
+        return summarize_model(model, args.dataset, which_summary=args.summary)
 
     # Load the datasets: the dataset to load is inferred from the model name passed
     # in args.arch.  The default dataset is ImageNet, but if args.arch contains the
@@ -282,39 +258,10 @@ def main():
         activations_sparsity = ActivationSparsityCollector(model)
 
     if args.sensitivity is not None:
-        # This sample application can be invoked to execute Sensitivity Analysis on your
-        # model.  The ouptut is saved to CSV and PNG.
-        msglogger.info("Running sensitivity tests")
-        test_fnc = partial(test, test_loader=test_loader, criterion=criterion,
-                           loggers=[pylogger], args=args)
-        which_params = [param_name for param_name, _ in model.named_parameters()]
-        sensitivity = distiller.perform_sensitivity_analysis(model,
-                                                             net_params=which_params,
-                                                             sparsities=np.arange(0.0, 0.95, 0.05),
-                                                             test_func=test_fnc,
-                                                             group=args.sensitivity)
-        distiller.sensitivities_to_png(sensitivity, 'sensitivity.png')
-        distiller.sensitivities_to_csv(sensitivity, 'sensitivity.csv')
-        exit()
+        return sensitivity_analysis(model, criterion, test_loader, pylogger, args)
 
     if args.evaluate:
-        # This sample application can be invoked to evaluate the accuracy of your model on
-        # the test dataset.
-        # You can optionally quantize the model to 8-bit integer before evaluation.
-        # For example:
-        # python3 compress_classifier.py --arch resnet20_cifar  ../data.cifar10 -p=50 --resume=checkpoint.pth.tar --evaluate
-        if args.quantize:
-            model.cpu()
-            quantizer = quantization.SymmetricLinearQuantizer(model, 8, 8)
-            quantizer.prepare_model()
-            model.cuda()
-        top1, _, _ = test(test_loader, model, criterion, [pylogger], args=args)
-        if args.quantize:
-            checkpoint_name = 'quantized'
-            apputils.save_checkpoint(0, args.arch, model, optimizer=None, best_top1=top1,
-                                     name='_'.split(args.name, checkpoint_name) if args.name else checkpoint_name,
-                                     dir=msglogger.logdir)
-        exit()
+        return evaluate_model(model, criterion, test_loader, pylogger, args)
 
     if args.compress:
         # The main use-case for this sample application is CNN compression. Compression
@@ -532,6 +479,79 @@ def get_inference_var(tensor):
     return torch.autograd.Variable(tensor, volatile=True)
 
 
+def evaluate_model(model, criterion, test_loader, loggers, args):
+    # This sample application can be invoked to evaluate the accuracy of your model on
+    # the test dataset.
+    # You can optionally quantize the model to 8-bit integer before evaluation.
+    # For example:
+    # python3 compress_classifier.py --arch resnet20_cifar  ../data.cifar10 -p=50 --resume=checkpoint.pth.tar --evaluate
+
+    if not isinstance(loggers, list):
+        loggers = [loggers]
+
+    if args.quantize:
+        model.cpu()
+        quantizer = quantization.SymmetricLinearQuantizer(model, 8, 8)
+        quantizer.prepare_model()
+        model.cuda()
+    top1, _, _ = test(test_loader, model, criterion, loggers, args=args)
+    if args.quantize:
+        checkpoint_name = 'quantized'
+        apputils.save_checkpoint(0, args.arch, model, optimizer=None, best_top1=top1,
+                                 name='_'.join([args.name, checkpoint_name]) if args.name else checkpoint_name,
+                                 dir=msglogger.logdir)
+
+
+def summarize_model(model, dataset, which_summary):
+    if which_summary.startswith('png'):
+        apputils.draw_img_classifier_to_file(model, 'model.png', dataset, which_summary == 'png_w_params')
+    else:
+        distiller.model_summary(model, which_summary, dataset)
+
+
+def sensitivity_analysis(model, criterion, data_loader, loggers, args):
+    # This sample application can be invoked to execute Sensitivity Analysis on your
+    # model.  The output is saved to CSV and PNG.
+    msglogger.info("Running sensitivity tests")
+    if not isinstance(loggers, list):
+        loggers = [loggers]
+    test_fnc = partial(test, test_loader=data_loader, criterion=criterion,
+                       loggers=loggers, args=args)
+    which_params = [param_name for param_name, _ in model.named_parameters()]
+    sensitivity = distiller.perform_sensitivity_analysis(model,
+                                                         net_params=which_params,
+                                                         sparsities=np.arange(0.0, 0.95, 0.05),
+                                                         test_func=test_fnc,
+                                                         group=args.sensitivity)
+    distiller.sensitivities_to_png(sensitivity, 'sensitivity.png')
+    distiller.sensitivities_to_csv(sensitivity, 'sensitivity.csv')
+
+
+def automated_deep_compression(model, criterion, loggers, args):
+    import examples.automated_deep_compression.ADC as ADC
+    HAVE_COACH_INSTALLED = True
+    if not HAVE_COACH_INSTALLED:
+        raise ValueError("ADC is currently experimental and uses non-public Coach features")
+
+    if not isinstance(loggers, list):
+        loggers = [loggers]
+
+    train_loader, val_loader, test_loader, _ = apputils.load_data(
+        args.dataset, os.path.expanduser(args.data), args.batch_size,
+        args.workers, args.validation_size, args.deterministic)
+
+    args.display_confusion = True
+    validate_fn = partial(validate, val_loader=test_loader, criterion=criterion,
+                          loggers=loggers, args=args)
+
+    if args.ADC_params is not None:
+        ADC.summarize_experiment(args.ADC_params, args.dataset, args.arch, validate_fn)
+        exit()
+
+    save_checkpoint_fn = partial(apputils.save_checkpoint, arch=args.arch, dir=msglogger.logdir)
+    ADC.do_adc(model, args.dataset, args.arch, val_loader, validate_fn, save_checkpoint_fn)
+
+
 if __name__ == '__main__':
     try:
         main()
-- 
GitLab