diff --git a/distiller/scheduler.py b/distiller/scheduler.py index 82266f04a0aa5f0e9f2e1af671f2dda1a4ffec89..763693c9abec5f414fe7f7a86d0187e46012014d 100755 --- a/distiller/scheduler.py +++ b/distiller/scheduler.py @@ -42,15 +42,29 @@ class CompressionScheduler(object): # Create the masker objects and place them in a dictionary indexed by the parameter name self.zeros_mask_dict = zeros_mask_dict or create_model_masks_dict(model) - def add_policy(self, policy, epochs=None, starting_epoch=0, ending_epoch=1, frequency=1): + def add_policy(self, policy, epochs=None, starting_epoch=None, ending_epoch=None, frequency=1): """Add a new policy to the schedule. Args: - epochs (list): A list, or range, of epochs in which to apply the policy + epochs (list): A list, or range, of epochs in which to apply the policy. + starting_epoch (integer): An integer number specifying at which epoch to start. + ending_epoch (integer): An integer number specifying at which epoch to end. + frequency (integer): An integer number specifying how often to invoke the policy. + + You may only provide a list of `epochs` or a range of epochs using `starting_epoch` + and `ending_epoch` (i.e. these are mutually-exclusive) """ + assert (epochs is None and None not in (starting_epoch, ending_epoch, frequency)) or\ + (epochs is not None and all (c is None for c in (starting_epoch, ending_epoch))) if epochs is None: + assert 0 <= starting_epoch < ending_epoch + assert 0 < frequency <= (ending_epoch - starting_epoch) epochs = list(range(starting_epoch, ending_epoch, frequency)) + else: + starting_epoch = epochs[0] + ending_epoch = epochs[-1] + 1 + frequency = None for epoch in epochs: if epoch not in self.policies: diff --git a/tests/test_infra.py b/tests/test_infra.py index f47b96999f321314a2a0efc227e726fd5af791a7..e5fed03ac40ae299054156ad6006167519dea6d7 100755 --- a/tests/test_infra.py +++ b/tests/test_infra.py @@ -298,5 +298,29 @@ def test_load_checkpoint_without_model(): os.remove(temp_checkpoint) +def test_policy_scheduling(): + model = create_model(False, 'cifar10', 'resnet20_cifar') + scheduler = distiller.CompressionScheduler(model) + policy = distiller.PruningPolicy(None, None) + with pytest.raises(AssertionError): + scheduler.add_policy(policy) + with pytest.raises(AssertionError): + # Test for mutual-exclusive configuration + scheduler.add_policy(policy, epochs=[1,2,3], starting_epoch=4, ending_epoch=5, frequency=1) + + scheduler.add_policy(policy, epochs=None, starting_epoch=4, ending_epoch=5, frequency=1) + # Regression test for issue #176 - https://github.com/NervanaSystems/distiller/issues/176 + scheduler.add_policy(policy, epochs=[1, 2, 3]) + sched_metadata = scheduler.sched_metadata[policy] + assert sched_metadata['starting_epoch'] == 1 + assert sched_metadata['ending_epoch'] == 4 + assert sched_metadata['frequency'] is None + + scheduler.add_policy(policy, epochs=[5]) + sched_metadata = scheduler.sched_metadata[policy] + assert sched_metadata['starting_epoch'] == 5 + assert sched_metadata['ending_epoch'] == 6 + assert sched_metadata['frequency'] is None + if __name__ == '__main__': test_load_gpu_model_on_cpu_with_thinning() \ No newline at end of file