From acbb4b4d15dfbd722cbc9c5ebb36e688ef4ef02e Mon Sep 17 00:00:00 2001 From: Neta Zmora <31280975+nzmora@users.noreply.github.com> Date: Thu, 22 Nov 2018 10:58:45 +0200 Subject: [PATCH] Fix Issue 79 (#81) * Fix issue #79 Change the default values so that the following scheduler meta-data keys are always defined: 'starting_epoch', 'ending_epoch', 'frequency' * compress_classifier.py: add a new argument Allow the specification, from the command line arguments, of the range of pruning levels scanned when doing sensitivity analysis * Add regression test for issue #79 --- distiller/scheduler.py | 2 +- .../compress_classifier.py | 10 ++++++--- tests/full_flow_tests.py | 21 +++++++++++++++++-- 3 files changed, 27 insertions(+), 6 deletions(-) mode change 100644 => 100755 tests/full_flow_tests.py diff --git a/distiller/scheduler.py b/distiller/scheduler.py index 11b7e69..8b26eeb 100755 --- a/distiller/scheduler.py +++ b/distiller/scheduler.py @@ -83,7 +83,7 @@ class CompressionScheduler(object): masker = ParameterMasker(name) self.zeros_mask_dict[name] = masker - def add_policy(self, policy, epochs=None, starting_epoch=None, ending_epoch=None, frequency=None): + def add_policy(self, policy, epochs=None, starting_epoch=0, ending_epoch=1, frequency=1): """Add a new policy to the schedule. Args: diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index 7d70958..d56c50f 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -132,6 +132,9 @@ parser.add_argument('--compress', dest='compress', type=str, nargs='?', action=' help='configuration file for pruning the model (default is to use hard-coded schedule)') parser.add_argument('--sense', dest='sensitivity', choices=['element', 'filter', 'channel'], help='test the sensitivity of layers to pruning') +parser.add_argument('--sense-range', dest='sensitivity_range', type=float, nargs=3, default=[0.0, 0.95, 0.05], + help='an optional paramaeter for sensitivity testing providing the range of sparsities to test.\n' + 'This is equaivalent to creating sensitivities = np.arange(start, stop, step)') parser.add_argument('--extras', default=None, type=str, help='file with extra configuration information') parser.add_argument('--deterministic', '--det', action='store_true', @@ -338,7 +341,8 @@ def main(): activations_collectors = create_activation_stats_collectors(model, collection_phase=args.activation_stats) if args.sensitivity is not None: - return sensitivity_analysis(model, criterion, test_loader, pylogger, args) + sensitivities = np.arange(args.sensitivity_range[0], args.sensitivity_range[1], args.sensitivity_range[2]) + return sensitivity_analysis(model, criterion, test_loader, pylogger, args, sensitivities) if args.evaluate: return evaluate_model(model, criterion, test_loader, pylogger, activations_collectors, args) @@ -735,7 +739,7 @@ def summarize_model(model, dataset, which_summary): distiller.model_summary(model, which_summary, dataset) -def sensitivity_analysis(model, criterion, data_loader, loggers, args): +def sensitivity_analysis(model, criterion, data_loader, loggers, args, sparsities): # This sample application can be invoked to execute Sensitivity Analysis on your # model. The ouptut is saved to CSV and PNG. msglogger.info("Running sensitivity tests") @@ -747,7 +751,7 @@ def sensitivity_analysis(model, criterion, data_loader, loggers, args): which_params = [param_name for param_name, _ in model.named_parameters()] sensitivity = distiller.perform_sensitivity_analysis(model, net_params=which_params, - sparsities=np.arange(0.0, 0.95, 0.05), + sparsities=sparsities, test_func=test_fnc, group=args.sensitivity) distiller.sensitivities_to_png(sensitivity, 'sensitivity.png') diff --git a/tests/full_flow_tests.py b/tests/full_flow_tests.py old mode 100644 new mode 100755 index 889511f..f23c1c3 --- a/tests/full_flow_tests.py +++ b/tests/full_flow_tests.py @@ -95,6 +95,20 @@ def accuracy_checker(log, expected_top1, expected_top5): return compare_values('Top-5', expected_top5, float(tops[-1][1])) +def collateral_checker(log, *collateral_list): + """Test that the test produced the expected collaterals. + + A collateral_list is a list of tuples, where tuple elements are: + 0: file name + 1: expected file size + """ + for collateral in collateral_list: + statinfo = os.stat(collateral[0]) + if statinfo.st_size != collateral[1]: + return False + return True + + ########### # Test Configurations ########### @@ -107,7 +121,10 @@ test_configs = [ DS_CIFAR, accuracy_checker, [91.620, 99.630]), TestConfig('-a preact_resnet20_cifar --epochs 2 --compress {0}'. format(os.path.join('full_flow_tests', 'preact_resnet20_cifar_pact_test.yaml')), - DS_CIFAR, accuracy_checker, [48.290, 94.460]) + DS_CIFAR, accuracy_checker, [48.290, 94.460]), + TestConfig('-a resnet20_cifar --resume {0} --sense=filter --sense-range 0 0.10 0.05'. + format(os.path.join(examples_root, 'ssl', 'checkpoints', 'checkpoint_trained_dense.pth.tar')), + DS_CIFAR, collateral_checker, [('sensitivity.csv', 3188), ('sensitivity.png', 96158)]) ] @@ -212,4 +229,4 @@ def run_tests(): if __name__ == '__main__': - run_tests() \ No newline at end of file + run_tests() -- GitLab