diff --git a/apputils/checkpoint.py b/apputils/checkpoint.py index c5c23fd6d42ea49209fa1dd74683a40bc055a267..5f848831bf0fb80709d3b06ddd315bd83ced64f2 100755 --- a/apputils/checkpoint.py +++ b/apputils/checkpoint.py @@ -94,7 +94,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None): if 'compression_sched' in checkpoint: compression_scheduler = distiller.CompressionScheduler(model) - compression_scheduler.load_state_dict(checkpoint['compression_sched'], distiller.model_device(model)) + compression_scheduler.load_state_dict(checkpoint['compression_sched']) msglogger.info("Loaded compression schedule from checkpoint (epoch %d)", checkpoint['epoch']) else: diff --git a/distiller/scheduler.py b/distiller/scheduler.py index 4124f271ff99c75b5b01f34a2ce93419471528de..6e65238a0abc0d990207fd0e5782c3517bdddf2c 100755 --- a/distiller/scheduler.py +++ b/distiller/scheduler.py @@ -23,7 +23,7 @@ import logging import torch from .quantization.quantizer import FP_BKP_PREFIX from .policy import PolicyLoss, LossComponent - +from .utils import model_device msglogger = logging.getLogger() @@ -187,7 +187,7 @@ class CompressionScheduler(object): state = {'masks_dict': masks} return state - def load_state_dict(self, state, device): + def load_state_dict(self, state): """Loads the scheduler state. Currently the scheduler state is comprised only of the set of pruning masks. @@ -207,6 +207,7 @@ class CompressionScheduler(object): print("\t\t" + k) exit(1) + device = model_device(self.model) for name, mask in self.zeros_mask_dict.items(): masker = self.zeros_mask_dict[name] masker.mask = loaded_masks[name]