Skip to content
Snippets Groups Projects
Commit d1ef1930 authored by Neta Zmora's avatar Neta Zmora
Browse files

CPU support: correct the device used for pruning masks

When masks are loaded from a checkpoint file, they should use the
same device as the model.
parent 0edfb5a9
No related branches found
No related tags found
No related merge requests found
...@@ -94,7 +94,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None): ...@@ -94,7 +94,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None):
if 'compression_sched' in checkpoint: if 'compression_sched' in checkpoint:
compression_scheduler = distiller.CompressionScheduler(model) compression_scheduler = distiller.CompressionScheduler(model)
compression_scheduler.load_state_dict(checkpoint['compression_sched']) compression_scheduler.load_state_dict(checkpoint['compression_sched'], distiller.model_device(model))
msglogger.info("Loaded compression schedule from checkpoint (epoch %d)", msglogger.info("Loaded compression schedule from checkpoint (epoch %d)",
checkpoint['epoch']) checkpoint['epoch'])
else: else:
......
...@@ -187,7 +187,7 @@ class CompressionScheduler(object): ...@@ -187,7 +187,7 @@ class CompressionScheduler(object):
state = {'masks_dict': masks} state = {'masks_dict': masks}
return state return state
def load_state_dict(self, state): def load_state_dict(self, state, device):
"""Loads the scheduler state. """Loads the scheduler state.
Currently the scheduler state is comprised only of the set of pruning masks. Currently the scheduler state is comprised only of the set of pruning masks.
...@@ -210,6 +210,8 @@ class CompressionScheduler(object): ...@@ -210,6 +210,8 @@ class CompressionScheduler(object):
for name, mask in self.zeros_mask_dict.items(): for name, mask in self.zeros_mask_dict.items():
masker = self.zeros_mask_dict[name] masker = self.zeros_mask_dict[name]
masker.mask = loaded_masks[name] masker.mask = loaded_masks[name]
if masker.mask is not None:
masker.mask = masker.mask.to(device)
@staticmethod @staticmethod
def verify_policy_loss(policy_loss): def verify_policy_loss(policy_loss):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment