From fcb1ad905b560e7de2099f4d49c2b887d5430fa2 Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Tue, 28 Apr 2020 00:20:48 +0300 Subject: [PATCH] Fix issue #176 --- distiller/scheduler.py | 18 ++++++++++++++++-- tests/test_infra.py | 24 ++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/distiller/scheduler.py b/distiller/scheduler.py index 82266f0..763693c 100755 --- a/distiller/scheduler.py +++ b/distiller/scheduler.py @@ -42,15 +42,29 @@ class CompressionScheduler(object): # Create the masker objects and place them in a dictionary indexed by the parameter name self.zeros_mask_dict = zeros_mask_dict or create_model_masks_dict(model) - def add_policy(self, policy, epochs=None, starting_epoch=0, ending_epoch=1, frequency=1): + def add_policy(self, policy, epochs=None, starting_epoch=None, ending_epoch=None, frequency=1): """Add a new policy to the schedule. Args: - epochs (list): A list, or range, of epochs in which to apply the policy + epochs (list): A list, or range, of epochs in which to apply the policy. + starting_epoch (integer): An integer number specifying at which epoch to start. + ending_epoch (integer): An integer number specifying at which epoch to end. + frequency (integer): An integer number specifying how often to invoke the policy. + + You may only provide a list of `epochs` or a range of epochs using `starting_epoch` + and `ending_epoch` (i.e. these are mutually-exclusive) """ + assert (epochs is None and None not in (starting_epoch, ending_epoch, frequency)) or\ + (epochs is not None and all (c is None for c in (starting_epoch, ending_epoch))) if epochs is None: + assert 0 <= starting_epoch < ending_epoch + assert 0 < frequency <= (ending_epoch - starting_epoch) epochs = list(range(starting_epoch, ending_epoch, frequency)) + else: + starting_epoch = epochs[0] + ending_epoch = epochs[-1] + 1 + frequency = None for epoch in epochs: if epoch not in self.policies: diff --git a/tests/test_infra.py b/tests/test_infra.py index f47b969..e5fed03 100755 --- a/tests/test_infra.py +++ b/tests/test_infra.py @@ -298,5 +298,29 @@ def test_load_checkpoint_without_model(): os.remove(temp_checkpoint) +def test_policy_scheduling(): + model = create_model(False, 'cifar10', 'resnet20_cifar') + scheduler = distiller.CompressionScheduler(model) + policy = distiller.PruningPolicy(None, None) + with pytest.raises(AssertionError): + scheduler.add_policy(policy) + with pytest.raises(AssertionError): + # Test for mutual-exclusive configuration + scheduler.add_policy(policy, epochs=[1,2,3], starting_epoch=4, ending_epoch=5, frequency=1) + + scheduler.add_policy(policy, epochs=None, starting_epoch=4, ending_epoch=5, frequency=1) + # Regression test for issue #176 - https://github.com/NervanaSystems/distiller/issues/176 + scheduler.add_policy(policy, epochs=[1, 2, 3]) + sched_metadata = scheduler.sched_metadata[policy] + assert sched_metadata['starting_epoch'] == 1 + assert sched_metadata['ending_epoch'] == 4 + assert sched_metadata['frequency'] is None + + scheduler.add_policy(policy, epochs=[5]) + sched_metadata = scheduler.sched_metadata[policy] + assert sched_metadata['starting_epoch'] == 5 + assert sched_metadata['ending_epoch'] == 6 + assert sched_metadata['frequency'] is None + if __name__ == '__main__': test_load_gpu_model_on_cpu_with_thinning() \ No newline at end of file -- GitLab