diff --git a/apputils/checkpoint.py b/apputils/checkpoint.py index 5f848831bf0fb80709d3b06ddd315bd83ced64f2..c5c23fd6d42ea49209fa1dd74683a40bc055a267 100755 --- a/apputils/checkpoint.py +++ b/apputils/checkpoint.py @@ -94,7 +94,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None): if 'compression_sched' in checkpoint: compression_scheduler = distiller.CompressionScheduler(model) - compression_scheduler.load_state_dict(checkpoint['compression_sched']) + compression_scheduler.load_state_dict(checkpoint['compression_sched'], distiller.model_device(model)) msglogger.info("Loaded compression schedule from checkpoint (epoch %d)", checkpoint['epoch']) else: diff --git a/distiller/scheduler.py b/distiller/scheduler.py index 8b26eeb44c73d9769b1a726cab272784a6d0add0..4124f271ff99c75b5b01f34a2ce93419471528de 100755 --- a/distiller/scheduler.py +++ b/distiller/scheduler.py @@ -187,7 +187,7 @@ class CompressionScheduler(object): state = {'masks_dict': masks} return state - def load_state_dict(self, state): + def load_state_dict(self, state, device): """Loads the scheduler state. Currently the scheduler state is comprised only of the set of pruning masks. @@ -210,6 +210,8 @@ class CompressionScheduler(object): for name, mask in self.zeros_mask_dict.items(): masker = self.zeros_mask_dict[name] masker.mask = loaded_masks[name] + if masker.mask is not None: + masker.mask = masker.mask.to(device) @staticmethod def verify_policy_loss(policy_loss):