diff --git a/distiller/scheduler.py b/distiller/scheduler.py index a5e360d01d3e7dfa0bb861d6d9a87fda83ed1fa6..a24de1a4d064713818732ed6ae4d703b7285f2fa 100755 --- a/distiller/scheduler.py +++ b/distiller/scheduler.py @@ -189,7 +189,7 @@ class CompressionScheduler(object): state = {'masks_dict': masks} return state - def load_state_dict(self, state, normalize_dataparallel_keys): + def load_state_dict(self, state, normalize_dataparallel_keys=False): """Loads the scheduler state. Currently the scheduler state is comprised only of the set of pruning masks.