Skip to content
Snippets Groups Projects
Commit d1ef1930 authored by Neta Zmora's avatar Neta Zmora
Browse files

CPU support: correct the device used for pruning masks

When masks are loaded from a checkpoint file, they should use the
same device as the model.
parent 0edfb5a9
No related branches found
No related tags found
No related merge requests found
......@@ -94,7 +94,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None):
if 'compression_sched' in checkpoint:
compression_scheduler = distiller.CompressionScheduler(model)
compression_scheduler.load_state_dict(checkpoint['compression_sched'])
compression_scheduler.load_state_dict(checkpoint['compression_sched'], distiller.model_device(model))
msglogger.info("Loaded compression schedule from checkpoint (epoch %d)",
checkpoint['epoch'])
else:
......
......@@ -187,7 +187,7 @@ class CompressionScheduler(object):
state = {'masks_dict': masks}
return state
def load_state_dict(self, state):
def load_state_dict(self, state, device):
"""Loads the scheduler state.
Currently the scheduler state is comprised only of the set of pruning masks.
......@@ -210,6 +210,8 @@ class CompressionScheduler(object):
for name, mask in self.zeros_mask_dict.items():
masker = self.zeros_mask_dict[name]
masker.mask = loaded_masks[name]
if masker.mask is not None:
masker.mask = masker.mask.to(device)
@staticmethod
def verify_policy_loss(policy_loss):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment