Skip to content
Snippets Groups Projects
Unverified Commit acbb4b4d authored by Neta Zmora's avatar Neta Zmora Committed by GitHub
Browse files

Fix Issue 79 (#81)

* Fix issue #79

Change the default values so that the following scheduler meta-data keys
are always defined: 'starting_epoch', 'ending_epoch', 'frequency'

* compress_classifier.py: add a new argument

Allow the specification, from the command line arguments,  of the range of
pruning levels scanned when doing sensitivity analysis

* Add regression test for issue #79
parent 485cc421
No related branches found
No related tags found
No related merge requests found
...@@ -83,7 +83,7 @@ class CompressionScheduler(object): ...@@ -83,7 +83,7 @@ class CompressionScheduler(object):
masker = ParameterMasker(name) masker = ParameterMasker(name)
self.zeros_mask_dict[name] = masker self.zeros_mask_dict[name] = masker
def add_policy(self, policy, epochs=None, starting_epoch=None, ending_epoch=None, frequency=None): def add_policy(self, policy, epochs=None, starting_epoch=0, ending_epoch=1, frequency=1):
"""Add a new policy to the schedule. """Add a new policy to the schedule.
Args: Args:
......
...@@ -132,6 +132,9 @@ parser.add_argument('--compress', dest='compress', type=str, nargs='?', action=' ...@@ -132,6 +132,9 @@ parser.add_argument('--compress', dest='compress', type=str, nargs='?', action='
help='configuration file for pruning the model (default is to use hard-coded schedule)') help='configuration file for pruning the model (default is to use hard-coded schedule)')
parser.add_argument('--sense', dest='sensitivity', choices=['element', 'filter', 'channel'], parser.add_argument('--sense', dest='sensitivity', choices=['element', 'filter', 'channel'],
help='test the sensitivity of layers to pruning') help='test the sensitivity of layers to pruning')
parser.add_argument('--sense-range', dest='sensitivity_range', type=float, nargs=3, default=[0.0, 0.95, 0.05],
help='an optional paramaeter for sensitivity testing providing the range of sparsities to test.\n'
'This is equaivalent to creating sensitivities = np.arange(start, stop, step)')
parser.add_argument('--extras', default=None, type=str, parser.add_argument('--extras', default=None, type=str,
help='file with extra configuration information') help='file with extra configuration information')
parser.add_argument('--deterministic', '--det', action='store_true', parser.add_argument('--deterministic', '--det', action='store_true',
...@@ -338,7 +341,8 @@ def main(): ...@@ -338,7 +341,8 @@ def main():
activations_collectors = create_activation_stats_collectors(model, collection_phase=args.activation_stats) activations_collectors = create_activation_stats_collectors(model, collection_phase=args.activation_stats)
if args.sensitivity is not None: if args.sensitivity is not None:
return sensitivity_analysis(model, criterion, test_loader, pylogger, args) sensitivities = np.arange(args.sensitivity_range[0], args.sensitivity_range[1], args.sensitivity_range[2])
return sensitivity_analysis(model, criterion, test_loader, pylogger, args, sensitivities)
if args.evaluate: if args.evaluate:
return evaluate_model(model, criterion, test_loader, pylogger, activations_collectors, args) return evaluate_model(model, criterion, test_loader, pylogger, activations_collectors, args)
...@@ -735,7 +739,7 @@ def summarize_model(model, dataset, which_summary): ...@@ -735,7 +739,7 @@ def summarize_model(model, dataset, which_summary):
distiller.model_summary(model, which_summary, dataset) distiller.model_summary(model, which_summary, dataset)
def sensitivity_analysis(model, criterion, data_loader, loggers, args): def sensitivity_analysis(model, criterion, data_loader, loggers, args, sparsities):
# This sample application can be invoked to execute Sensitivity Analysis on your # This sample application can be invoked to execute Sensitivity Analysis on your
# model. The ouptut is saved to CSV and PNG. # model. The ouptut is saved to CSV and PNG.
msglogger.info("Running sensitivity tests") msglogger.info("Running sensitivity tests")
...@@ -747,7 +751,7 @@ def sensitivity_analysis(model, criterion, data_loader, loggers, args): ...@@ -747,7 +751,7 @@ def sensitivity_analysis(model, criterion, data_loader, loggers, args):
which_params = [param_name for param_name, _ in model.named_parameters()] which_params = [param_name for param_name, _ in model.named_parameters()]
sensitivity = distiller.perform_sensitivity_analysis(model, sensitivity = distiller.perform_sensitivity_analysis(model,
net_params=which_params, net_params=which_params,
sparsities=np.arange(0.0, 0.95, 0.05), sparsities=sparsities,
test_func=test_fnc, test_func=test_fnc,
group=args.sensitivity) group=args.sensitivity)
distiller.sensitivities_to_png(sensitivity, 'sensitivity.png') distiller.sensitivities_to_png(sensitivity, 'sensitivity.png')
......
...@@ -95,6 +95,20 @@ def accuracy_checker(log, expected_top1, expected_top5): ...@@ -95,6 +95,20 @@ def accuracy_checker(log, expected_top1, expected_top5):
return compare_values('Top-5', expected_top5, float(tops[-1][1])) return compare_values('Top-5', expected_top5, float(tops[-1][1]))
def collateral_checker(log, *collateral_list):
"""Test that the test produced the expected collaterals.
A collateral_list is a list of tuples, where tuple elements are:
0: file name
1: expected file size
"""
for collateral in collateral_list:
statinfo = os.stat(collateral[0])
if statinfo.st_size != collateral[1]:
return False
return True
########### ###########
# Test Configurations # Test Configurations
########### ###########
...@@ -107,7 +121,10 @@ test_configs = [ ...@@ -107,7 +121,10 @@ test_configs = [
DS_CIFAR, accuracy_checker, [91.620, 99.630]), DS_CIFAR, accuracy_checker, [91.620, 99.630]),
TestConfig('-a preact_resnet20_cifar --epochs 2 --compress {0}'. TestConfig('-a preact_resnet20_cifar --epochs 2 --compress {0}'.
format(os.path.join('full_flow_tests', 'preact_resnet20_cifar_pact_test.yaml')), format(os.path.join('full_flow_tests', 'preact_resnet20_cifar_pact_test.yaml')),
DS_CIFAR, accuracy_checker, [48.290, 94.460]) DS_CIFAR, accuracy_checker, [48.290, 94.460]),
TestConfig('-a resnet20_cifar --resume {0} --sense=filter --sense-range 0 0.10 0.05'.
format(os.path.join(examples_root, 'ssl', 'checkpoints', 'checkpoint_trained_dense.pth.tar')),
DS_CIFAR, collateral_checker, [('sensitivity.csv', 3188), ('sensitivity.png', 96158)])
] ]
...@@ -212,4 +229,4 @@ def run_tests(): ...@@ -212,4 +229,4 @@ def run_tests():
if __name__ == '__main__': if __name__ == '__main__':
run_tests() run_tests()
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment