Skip to content
Snippets Groups Projects
Commit fcb1ad90 authored by Neta Zmora's avatar Neta Zmora
Browse files

Fix issue #176

parent f764a8aa
No related branches found
No related tags found
No related merge requests found
......@@ -42,15 +42,29 @@ class CompressionScheduler(object):
# Create the masker objects and place them in a dictionary indexed by the parameter name
self.zeros_mask_dict = zeros_mask_dict or create_model_masks_dict(model)
def add_policy(self, policy, epochs=None, starting_epoch=0, ending_epoch=1, frequency=1):
def add_policy(self, policy, epochs=None, starting_epoch=None, ending_epoch=None, frequency=1):
"""Add a new policy to the schedule.
Args:
epochs (list): A list, or range, of epochs in which to apply the policy
epochs (list): A list, or range, of epochs in which to apply the policy.
starting_epoch (integer): An integer number specifying at which epoch to start.
ending_epoch (integer): An integer number specifying at which epoch to end.
frequency (integer): An integer number specifying how often to invoke the policy.
You may only provide a list of `epochs` or a range of epochs using `starting_epoch`
and `ending_epoch` (i.e. these are mutually-exclusive)
"""
assert (epochs is None and None not in (starting_epoch, ending_epoch, frequency)) or\
(epochs is not None and all (c is None for c in (starting_epoch, ending_epoch)))
if epochs is None:
assert 0 <= starting_epoch < ending_epoch
assert 0 < frequency <= (ending_epoch - starting_epoch)
epochs = list(range(starting_epoch, ending_epoch, frequency))
else:
starting_epoch = epochs[0]
ending_epoch = epochs[-1] + 1
frequency = None
for epoch in epochs:
if epoch not in self.policies:
......
......@@ -298,5 +298,29 @@ def test_load_checkpoint_without_model():
os.remove(temp_checkpoint)
def test_policy_scheduling():
model = create_model(False, 'cifar10', 'resnet20_cifar')
scheduler = distiller.CompressionScheduler(model)
policy = distiller.PruningPolicy(None, None)
with pytest.raises(AssertionError):
scheduler.add_policy(policy)
with pytest.raises(AssertionError):
# Test for mutual-exclusive configuration
scheduler.add_policy(policy, epochs=[1,2,3], starting_epoch=4, ending_epoch=5, frequency=1)
scheduler.add_policy(policy, epochs=None, starting_epoch=4, ending_epoch=5, frequency=1)
# Regression test for issue #176 - https://github.com/NervanaSystems/distiller/issues/176
scheduler.add_policy(policy, epochs=[1, 2, 3])
sched_metadata = scheduler.sched_metadata[policy]
assert sched_metadata['starting_epoch'] == 1
assert sched_metadata['ending_epoch'] == 4
assert sched_metadata['frequency'] is None
scheduler.add_policy(policy, epochs=[5])
sched_metadata = scheduler.sched_metadata[policy]
assert sched_metadata['starting_epoch'] == 5
assert sched_metadata['ending_epoch'] == 6
assert sched_metadata['frequency'] is None
if __name__ == '__main__':
test_load_gpu_model_on_cpu_with_thinning()
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment