From fcb1ad905b560e7de2099f4d49c2b887d5430fa2 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Tue, 28 Apr 2020 00:20:48 +0300
Subject: [PATCH] Fix issue #176

---
 distiller/scheduler.py | 18 ++++++++++++++++--
 tests/test_infra.py    | 24 ++++++++++++++++++++++++
 2 files changed, 40 insertions(+), 2 deletions(-)

diff --git a/distiller/scheduler.py b/distiller/scheduler.py
index 82266f0..763693c 100755
--- a/distiller/scheduler.py
+++ b/distiller/scheduler.py
@@ -42,15 +42,29 @@ class CompressionScheduler(object):
         # Create the masker objects and place them in a dictionary indexed by the parameter name
         self.zeros_mask_dict = zeros_mask_dict or create_model_masks_dict(model)
 
-    def add_policy(self, policy, epochs=None, starting_epoch=0, ending_epoch=1, frequency=1):
+    def add_policy(self, policy, epochs=None, starting_epoch=None, ending_epoch=None, frequency=1):
         """Add a new policy to the schedule.
 
         Args:
-            epochs (list): A list, or range, of epochs in which to apply the policy
+            epochs (list): A list, or range, of epochs in which to apply the policy.
+            starting_epoch (integer): An integer number specifying at which epoch to start.
+            ending_epoch (integer): An integer number specifying at which epoch to end.
+            frequency (integer): An integer number specifying how often to invoke the policy.
+
+            You may only provide a list of `epochs` or a range of epochs using `starting_epoch`
+            and `ending_epoch` (i.e. these are mutually-exclusive)
         """
+        assert (epochs is None and None not in (starting_epoch, ending_epoch, frequency)) or\
+               (epochs is not None and all (c is None for c in (starting_epoch, ending_epoch)))
 
         if epochs is None:
+            assert 0 <= starting_epoch < ending_epoch
+            assert 0 < frequency <= (ending_epoch - starting_epoch)
             epochs = list(range(starting_epoch, ending_epoch, frequency))
+        else:
+            starting_epoch = epochs[0]
+            ending_epoch = epochs[-1] + 1
+            frequency = None
 
         for epoch in epochs:
             if epoch not in self.policies:
diff --git a/tests/test_infra.py b/tests/test_infra.py
index f47b969..e5fed03 100755
--- a/tests/test_infra.py
+++ b/tests/test_infra.py
@@ -298,5 +298,29 @@ def test_load_checkpoint_without_model():
         os.remove(temp_checkpoint)
 
 
+def test_policy_scheduling():
+    model = create_model(False, 'cifar10', 'resnet20_cifar')
+    scheduler = distiller.CompressionScheduler(model)
+    policy = distiller.PruningPolicy(None, None)
+    with pytest.raises(AssertionError):
+        scheduler.add_policy(policy)
+    with pytest.raises(AssertionError):
+        # Test for mutual-exclusive configuration
+        scheduler.add_policy(policy, epochs=[1,2,3], starting_epoch=4, ending_epoch=5, frequency=1)
+
+    scheduler.add_policy(policy, epochs=None, starting_epoch=4, ending_epoch=5, frequency=1)
+    # Regression test for issue #176 - https://github.com/NervanaSystems/distiller/issues/176
+    scheduler.add_policy(policy, epochs=[1, 2, 3])
+    sched_metadata = scheduler.sched_metadata[policy]
+    assert sched_metadata['starting_epoch'] == 1
+    assert sched_metadata['ending_epoch'] == 4
+    assert sched_metadata['frequency'] is None
+
+    scheduler.add_policy(policy, epochs=[5])
+    sched_metadata = scheduler.sched_metadata[policy]
+    assert sched_metadata['starting_epoch'] == 5
+    assert sched_metadata['ending_epoch'] == 6
+    assert sched_metadata['frequency'] is None
+
 if __name__ == '__main__':
     test_load_gpu_model_on_cpu_with_thinning()
\ No newline at end of file
-- 
GitLab