From acbb4b4d15dfbd722cbc9c5ebb36e688ef4ef02e Mon Sep 17 00:00:00 2001
From: Neta Zmora <31280975+nzmora@users.noreply.github.com>
Date: Thu, 22 Nov 2018 10:58:45 +0200
Subject: [PATCH] Fix Issue 79 (#81)

* Fix issue #79

Change the default values so that the following scheduler meta-data keys
are always defined: 'starting_epoch', 'ending_epoch', 'frequency'
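
A minimal standalone sketch of the new behavior (this only mirrors the
changed default arguments of CompressionScheduler.add_policy(); the
returned dictionary is illustrative, not the distiller implementation):

    def add_policy(policy, epochs=None, starting_epoch=0, ending_epoch=1, frequency=1):
        # With the new defaults these keys are always well-defined integers,
        # even when the caller omits them (the situation behind issue #79).
        return {'starting_epoch': starting_epoch,
                'ending_epoch': ending_epoch,
                'frequency': frequency}

    assert add_policy(policy=None) == {'starting_epoch': 0,
                                       'ending_epoch': 1,
                                       'frequency': 1}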

* compress_classifier.py: add a new argument

Allow specifying, via a command-line argument, the range of pruning
levels scanned when doing sensitivity analysis
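
For example (the dataset path and checkpoint file below are illustrative,
not part of this patch):

    $ python compress_classifier.py -a resnet20_cifar ../../../data.cifar10 \
        --resume checkpoint_trained_dense.pth.tar --sense=filter --sense-range 0 0.10 0.05

scans the sparsity levels produced by np.arange(0, 0.10, 0.05), i.e. 0.0 and 0.05.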

* Add regression test for issue #79
---
 distiller/scheduler.py                        |  2 +-
 .../compress_classifier.py                    | 10 ++++++---
 tests/full_flow_tests.py                      | 21 +++++++++++++++++--
 3 files changed, 27 insertions(+), 6 deletions(-)
 mode change 100644 => 100755 tests/full_flow_tests.py

diff --git a/distiller/scheduler.py b/distiller/scheduler.py
index 11b7e69..8b26eeb 100755
--- a/distiller/scheduler.py
+++ b/distiller/scheduler.py
@@ -83,7 +83,7 @@ class CompressionScheduler(object):
             masker = ParameterMasker(name)
             self.zeros_mask_dict[name] = masker
 
-    def add_policy(self, policy, epochs=None, starting_epoch=None, ending_epoch=None, frequency=None):
+    def add_policy(self, policy, epochs=None, starting_epoch=0, ending_epoch=1, frequency=1):
         """Add a new policy to the schedule.
 
         Args:
diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index 7d70958..d56c50f 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -132,6 +132,9 @@ parser.add_argument('--compress', dest='compress', type=str, nargs='?', action='
                     help='configuration file for pruning the model (default is to use hard-coded schedule)')
 parser.add_argument('--sense', dest='sensitivity', choices=['element', 'filter', 'channel'],
                     help='test the sensitivity of layers to pruning')
+parser.add_argument('--sense-range', dest='sensitivity_range', type=float, nargs=3, default=[0.0, 0.95, 0.05],
+                    help='an optional parameter for sensitivity testing, providing the range of sparsities to test.\n'
+                    'This is equivalent to creating sensitivities = np.arange(start, stop, step)')
 parser.add_argument('--extras', default=None, type=str,
                     help='file with extra configuration information')
 parser.add_argument('--deterministic', '--det', action='store_true',
@@ -338,7 +341,8 @@ def main():
     activations_collectors = create_activation_stats_collectors(model, collection_phase=args.activation_stats)
 
     if args.sensitivity is not None:
-        return sensitivity_analysis(model, criterion, test_loader, pylogger, args)
+        sensitivities = np.arange(args.sensitivity_range[0], args.sensitivity_range[1], args.sensitivity_range[2])
+        return sensitivity_analysis(model, criterion, test_loader, pylogger, args, sensitivities)
 
     if args.evaluate:
         return evaluate_model(model, criterion, test_loader, pylogger, activations_collectors, args)
@@ -735,7 +739,7 @@ def summarize_model(model, dataset, which_summary):
         distiller.model_summary(model, which_summary, dataset)
 
 
-def sensitivity_analysis(model, criterion, data_loader, loggers, args):
+def sensitivity_analysis(model, criterion, data_loader, loggers, args, sparsities):
     # This sample application can be invoked to execute Sensitivity Analysis on your
     # model.  The output is saved to CSV and PNG.
     msglogger.info("Running sensitivity tests")
@@ -747,7 +751,7 @@ def sensitivity_analysis(model, criterion, data_loader, loggers, args):
     which_params = [param_name for param_name, _ in model.named_parameters()]
     sensitivity = distiller.perform_sensitivity_analysis(model,
                                                          net_params=which_params,
-                                                         sparsities=np.arange(0.0, 0.95, 0.05),
+                                                         sparsities=sparsities,
                                                          test_func=test_fnc,
                                                          group=args.sensitivity)
     distiller.sensitivities_to_png(sensitivity, 'sensitivity.png')
diff --git a/tests/full_flow_tests.py b/tests/full_flow_tests.py
old mode 100644
new mode 100755
index 889511f..f23c1c3
--- a/tests/full_flow_tests.py
+++ b/tests/full_flow_tests.py
@@ -95,6 +95,20 @@ def accuracy_checker(log, expected_top1, expected_top5):
     return compare_values('Top-5', expected_top5, float(tops[-1][1]))
 
 
+def collateral_checker(log, *collateral_list):
+    """Test that the test produced the expected collaterals.
+
+    A collateral_list is a list of tuples, where tuple elements are:
+        0: file name
+        1: expected file size
+    """
+    for collateral in collateral_list:
+        statinfo = os.stat(collateral[0])
+        if statinfo.st_size != collateral[1]:
+            return False
+    return True
+
+
 ###########
 # Test Configurations
 ###########
@@ -107,7 +121,10 @@ test_configs = [
                DS_CIFAR, accuracy_checker, [91.620, 99.630]),
     TestConfig('-a preact_resnet20_cifar --epochs 2 --compress {0}'.
                format(os.path.join('full_flow_tests', 'preact_resnet20_cifar_pact_test.yaml')),
-               DS_CIFAR, accuracy_checker, [48.290, 94.460])
+               DS_CIFAR, accuracy_checker, [48.290, 94.460]),
+    TestConfig('-a resnet20_cifar --resume {0} --sense=filter --sense-range 0 0.10 0.05'.
+               format(os.path.join(examples_root, 'ssl', 'checkpoints', 'checkpoint_trained_dense.pth.tar')),
+               DS_CIFAR, collateral_checker, [('sensitivity.csv', 3188), ('sensitivity.png', 96158)])
 ]
 
 
@@ -212,4 +229,4 @@ def run_tests():
 
 
 if __name__ == '__main__':
-    run_tests()
\ No newline at end of file
+    run_tests()
-- 
GitLab