From fb44c8d007ff1f3046c0e72f922d9b002a6126b8 Mon Sep 17 00:00:00 2001
From: Guy Jacob <guy.jacob@intel.com>
Date: Fri, 22 Jun 2018 03:00:12 +0300
Subject: [PATCH] Add missing documentation and missing input check in
 MultiStepMultiGammaLR

---
 distiller/learning_rate.py  | 10 ++++++++++
 tests/test_learning_rate.py | 11 +++++++++++
 2 files changed, 21 insertions(+)

diff --git a/distiller/learning_rate.py b/distiller/learning_rate.py
index 657d018..eb5697e 100644
--- a/distiller/learning_rate.py
+++ b/distiller/learning_rate.py
@@ -41,10 +41,20 @@ class PolynomialLR(_LRScheduler):
 
 
 class MultiStepMultiGammaLR(_LRScheduler):
+    """Similar to torch.otpim.MultiStepLR, but instead of a single gamma value, specify a gamma value per-milestone.
+
+    Args:
+        optimizer (Optimizer): Wrapped optimizer.
+        milestones (list): List of epoch indices. Must be increasing.
+        gammas (list): List of gamma values. Must have same length as milestones.
+        last_epoch (int): The index of last epoch. Default: -1.
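+
+    Example (illustrative; assumes the wrapped optimizer is created with an initial lr of 0.1,
+        and train(...) / validate(...) stand in for the user's own routines):
+        >>> # lr = 0.1     if       epoch < 30
+        >>> # lr = 0.01    if 30 <= epoch < 60
+        >>> # lr = 0.001   if 60 <= epoch < 80
+        >>> # lr = 0.0002  if       epoch >= 80
+        >>> scheduler = MultiStepMultiGammaLR(optimizer, milestones=[30, 60, 80], gammas=[0.1, 0.1, 0.2])
+        >>> for epoch in range(100):
+        >>>     scheduler.step()
+        >>>     train(...)
+        >>>     validate(...)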
+    """
     def __init__(self, optimizer, milestones, gammas, last_epoch=-1):
         if not list(milestones) == sorted(milestones):
             raise ValueError('Milestones should be a list of'
                              ' increasing integers. Got {}', milestones)
+        if len(milestones) != len(gammas):
+            raise ValueError('Milestones and Gammas lists should be of the same length.')
 
         self.milestones = milestones
         self.multiplicative_gammas = [1]
diff --git a/tests/test_learning_rate.py b/tests/test_learning_rate.py
index 6e63121..42767bb 100644
--- a/tests/test_learning_rate.py
+++ b/tests/test_learning_rate.py
@@ -16,6 +16,7 @@
 
 import os
 import sys
+import pytest
 module_path = os.path.abspath(os.path.join('..'))
 if module_path not in sys.path:
     sys.path.append(module_path)
@@ -28,6 +29,16 @@ from distiller.learning_rate import MultiStepMultiGammaLR
 def test_multi_step_multi_gamma_lr():
     dummy_tensor = torch.zeros(3, 3, 3, requires_grad=True)
     dummy_optimizer = Optimizer([dummy_tensor], {'lr': 0.1})
+
+    # Test input checks
+    with pytest.raises(ValueError):
+        lr_sched = MultiStepMultiGammaLR(dummy_optimizer, milestones=[60, 30, 80], gammas=[0.1, 0.1, 0.2])
+    with pytest.raises(ValueError):
+        lr_sched = MultiStepMultiGammaLR(dummy_optimizer, milestones=[30, 60], gammas=[0.1, 0.1, 0.2])
+    with pytest.raises(ValueError):
+        lr_sched = MultiStepMultiGammaLR(dummy_optimizer, milestones=[30, 60, 80], gammas=[0.1, 0.1])
+
+    # Test functionality
     lr_sched = MultiStepMultiGammaLR(dummy_optimizer, milestones=[30, 60, 80], gammas=[0.1, 0.1, 0.2])
     expected_gammas = [1, 1 * 0.1, 1 * 0.1 * 0.1, 1 * 0.1 * 0.1 * 0.2]
     expected_lrs = [0.1 * gamma for gamma in expected_gammas]
-- 
GitLab