From fb44c8d007ff1f3046c0e72f922d9b002a6126b8 Mon Sep 17 00:00:00 2001 From: Guy Jacob <guy.jacob@intel.com> Date: Fri, 22 Jun 2018 03:00:12 +0300 Subject: [PATCH] Add missing documentation and missing input check in MultiStepMultiGammaLR --- distiller/learning_rate.py | 10 ++++++++++ tests/test_learning_rate.py | 11 +++++++++++ 2 files changed, 21 insertions(+) diff --git a/distiller/learning_rate.py b/distiller/learning_rate.py index 657d018..eb5697e 100644 --- a/distiller/learning_rate.py +++ b/distiller/learning_rate.py @@ -41,10 +41,20 @@ class PolynomialLR(_LRScheduler): class MultiStepMultiGammaLR(_LRScheduler): + """Similar to torch.otpim.MultiStepLR, but instead of a single gamma value, specify a gamma value per-milestone. + + Args: + optimizer (Optimizer): Wrapped optimizer. + milestones (list): List of epoch indices. Must be increasing. + gammas (list): List of gamma values. Must have same length as milestones. + last_epoch (int): The index of last epoch. Default: -1. + """ def __init__(self, optimizer, milestones, gammas, last_epoch=-1): if not list(milestones) == sorted(milestones): raise ValueError('Milestones should be a list of' ' increasing integers. Got {}', milestones) + if len(milestones) != len(gammas): + raise ValueError('Milestones and Gammas lists should be of same length.') self.milestones = milestones self.multiplicative_gammas = [1] diff --git a/tests/test_learning_rate.py b/tests/test_learning_rate.py index 6e63121..42767bb 100644 --- a/tests/test_learning_rate.py +++ b/tests/test_learning_rate.py @@ -16,6 +16,7 @@ import os import sys +import pytest module_path = os.path.abspath(os.path.join('..')) if module_path not in sys.path: sys.path.append(module_path) @@ -28,6 +29,16 @@ from distiller.learning_rate import MultiStepMultiGammaLR def test_multi_step_multi_gamma_lr(): dummy_tensor = torch.zeros(3, 3, 3, requires_grad=True) dummy_optimizer = Optimizer([dummy_tensor], {'lr': 0.1}) + + # Test input checks + with pytest.raises(ValueError): + lr_sched = MultiStepMultiGammaLR(dummy_optimizer, milestones=[60, 30, 80], gammas=[0.1, 0.1, 0.2]) + with pytest.raises(ValueError): + lr_sched = MultiStepMultiGammaLR(dummy_optimizer, milestones=[30, 60], gammas=[0.1, 0.1, 0.2]) + with pytest.raises(ValueError): + lr_sched = MultiStepMultiGammaLR(dummy_optimizer, milestones=[30, 60, 80], gammas=[0.1, 0.1]) + + # Test functionality lr_sched = MultiStepMultiGammaLR(dummy_optimizer, milestones=[30, 60, 80], gammas=[0.1, 0.1, 0.2]) expected_gammas = [1, 1 * 0.1, 1 * 0.1 * 0.1, 1 * 0.1 * 0.1 * 0.2] expected_lrs = [0.1 * gamma for gamma in expected_gammas] -- GitLab