diff --git a/distiller/learning_rate.py b/distiller/learning_rate.py
index 657d0181c40d331a20d2d1ff23792ef74e904aed..eb5697ee862e692f20a87144141727855a666fb2 100644
--- a/distiller/learning_rate.py
+++ b/distiller/learning_rate.py
@@ -41,10 +41,34 @@ class PolynomialLR(_LRScheduler):
 
 
 class MultiStepMultiGammaLR(_LRScheduler):
+    """Similar to torch.otpim.MultiStepLR, but instead of a single gamma value, specify a gamma value per-milestone.
+
+    Args:
+        optimizer (Optimizer): Wrapped optimizer.
+        milestones (list): List of epoch indices. Must be increasing.
+        gammas (list): List of gamma values. Must have same length as milestones.
+        last_epoch (int): The index of last epoch. Default: -1.
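+
+    Example:
+        >>> # A usage sketch; assumes a single param group with an
+        >>> # initial lr of 0.1, and train()/validate() standing in
+        >>> # for the user's own loop.
+        >>> # lr = 0.1     if epoch < 30
+        >>> # lr = 0.01    if 30 <= epoch < 60
+        >>> # lr = 0.001   if 60 <= epoch < 80
+        >>> # lr = 0.0002  if epoch >= 80
+        >>> scheduler = MultiStepMultiGammaLR(optimizer, milestones=[30, 60, 80], gammas=[0.1, 0.1, 0.2])
+        >>> for epoch in range(100):
+        >>>     scheduler.step()
+        >>>     train(...)
+        >>>     validate(...)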
+    """
     def __init__(self, optimizer, milestones, gammas, last_epoch=-1):
         if not list(milestones) == sorted(milestones):
             raise ValueError('Milestones should be a list of'
                              ' increasing integers. Got {}'.format(milestones))
+        if len(milestones) != len(gammas):
+            raise ValueError('Milestones and gammas lists should be of the same length.')
 
         self.milestones = milestones
         self.multiplicative_gammas = [1]
diff --git a/tests/test_learning_rate.py b/tests/test_learning_rate.py
index 6e63121c0122e246194738e14aa34eafa5ac05af..42767bb6601e540a037e33680921513fc5c7d187 100644
--- a/tests/test_learning_rate.py
+++ b/tests/test_learning_rate.py
@@ -16,6 +16,7 @@
 
 import os
 import sys
+import pytest
 module_path = os.path.abspath(os.path.join('..'))
 if module_path not in sys.path:
     sys.path.append(module_path)
@@ -28,6 +29,19 @@ from distiller.learning_rate import MultiStepMultiGammaLR
 def test_multi_step_multi_gamma_lr():
     dummy_tensor = torch.zeros(3, 3, 3, requires_grad=True)
     dummy_optimizer = Optimizer([dummy_tensor], {'lr': 0.1})
+
+    # Test input checks
+    with pytest.raises(ValueError):
+        # Milestones must be increasing
+        MultiStepMultiGammaLR(dummy_optimizer, milestones=[60, 30, 80], gammas=[0.1, 0.1, 0.2])
+    with pytest.raises(ValueError):
+        # Fewer milestones than gammas
+        MultiStepMultiGammaLR(dummy_optimizer, milestones=[30, 60], gammas=[0.1, 0.1, 0.2])
+    with pytest.raises(ValueError):
+        # More milestones than gammas
+        MultiStepMultiGammaLR(dummy_optimizer, milestones=[30, 60, 80], gammas=[0.1, 0.1])
+
+    # Test functionality
     lr_sched = MultiStepMultiGammaLR(dummy_optimizer, milestones=[30, 60, 80], gammas=[0.1, 0.1, 0.2])
     expected_gammas = [1, 1 * 0.1, 1 * 0.1 * 0.1, 1 * 0.1 * 0.1 * 0.2]
     expected_lrs = [0.1 * gamma for gamma in expected_gammas]