Skip to content
Snippets Groups Projects
Commit fb44c8d0 authored by Guy Jacob's avatar Guy Jacob
Browse files

Add missing documentation and missing input check in MultiStepMultiGammaLR

parent 02f7871b
No related branches found
No related tags found
No related merge requests found
......@@ -41,10 +41,20 @@ class PolynomialLR(_LRScheduler):
class MultiStepMultiGammaLR(_LRScheduler):
"""Similar to torch.otpim.MultiStepLR, but instead of a single gamma value, specify a gamma value per-milestone.
Args:
optimizer (Optimizer): Wrapped optimizer.
milestones (list): List of epoch indices. Must be increasing.
gammas (list): List of gamma values. Must have same length as milestones.
last_epoch (int): The index of last epoch. Default: -1.
"""
def __init__(self, optimizer, milestones, gammas, last_epoch=-1):
if not list(milestones) == sorted(milestones):
raise ValueError('Milestones should be a list of'
' increasing integers. Got {}', milestones)
if len(milestones) != len(gammas):
raise ValueError('Milestones and Gammas lists should be of same length.')
self.milestones = milestones
self.multiplicative_gammas = [1]
......
......@@ -16,6 +16,7 @@
import os
import sys
import pytest
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
sys.path.append(module_path)
......@@ -28,6 +29,16 @@ from distiller.learning_rate import MultiStepMultiGammaLR
def test_multi_step_multi_gamma_lr():
dummy_tensor = torch.zeros(3, 3, 3, requires_grad=True)
dummy_optimizer = Optimizer([dummy_tensor], {'lr': 0.1})
# Test input checks
with pytest.raises(ValueError):
lr_sched = MultiStepMultiGammaLR(dummy_optimizer, milestones=[60, 30, 80], gammas=[0.1, 0.1, 0.2])
with pytest.raises(ValueError):
lr_sched = MultiStepMultiGammaLR(dummy_optimizer, milestones=[30, 60], gammas=[0.1, 0.1, 0.2])
with pytest.raises(ValueError):
lr_sched = MultiStepMultiGammaLR(dummy_optimizer, milestones=[30, 60, 80], gammas=[0.1, 0.1])
# Test functionality
lr_sched = MultiStepMultiGammaLR(dummy_optimizer, milestones=[30, 60, 80], gammas=[0.1, 0.1, 0.2])
expected_gammas = [1, 1 * 0.1, 1 * 0.1 * 0.1, 1 * 0.1 * 0.1 * 0.2]
expected_lrs = [0.1 * gamma for gamma in expected_gammas]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment