Skip to content
Snippets Groups Projects
Commit 4cc0e7d6 authored by Neta Zmora's avatar Neta Zmora
Browse files

Fix for CPU support

parent e564a05f
No related branches found
No related tags found
No related merge requests found
......@@ -94,7 +94,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None):
if 'compression_sched' in checkpoint:
compression_scheduler = distiller.CompressionScheduler(model)
compression_scheduler.load_state_dict(checkpoint['compression_sched'], distiller.model_device(model))
compression_scheduler.load_state_dict(checkpoint['compression_sched'])
msglogger.info("Loaded compression schedule from checkpoint (epoch %d)",
checkpoint['epoch'])
else:
......
......@@ -23,7 +23,7 @@ import logging
import torch
from .quantization.quantizer import FP_BKP_PREFIX
from .policy import PolicyLoss, LossComponent
from .utils import model_device
msglogger = logging.getLogger()
......@@ -187,7 +187,7 @@ class CompressionScheduler(object):
state = {'masks_dict': masks}
return state
def load_state_dict(self, state, device):
def load_state_dict(self, state):
"""Loads the scheduler state.
Currently the scheduler state is comprised only of the set of pruning masks.
......@@ -207,6 +207,7 @@ class CompressionScheduler(object):
print("\t\t" + k)
exit(1)
device = model_device(self.model)
for name, mask in self.zeros_mask_dict.items():
masker = self.zeros_mask_dict[name]
masker.mask = loaded_masks[name]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment