From 4cc0e7d6e2749b8d0c8014836f9ea99cf40a02df Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Wed, 16 Jan 2019 12:52:12 +0200 Subject: [PATCH] Fix for CPU support --- apputils/checkpoint.py | 2 +- distiller/scheduler.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/apputils/checkpoint.py b/apputils/checkpoint.py index c5c23fd..5f84883 100755 --- a/apputils/checkpoint.py +++ b/apputils/checkpoint.py @@ -94,7 +94,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None): if 'compression_sched' in checkpoint: compression_scheduler = distiller.CompressionScheduler(model) - compression_scheduler.load_state_dict(checkpoint['compression_sched'], distiller.model_device(model)) + compression_scheduler.load_state_dict(checkpoint['compression_sched']) msglogger.info("Loaded compression schedule from checkpoint (epoch %d)", checkpoint['epoch']) else: diff --git a/distiller/scheduler.py b/distiller/scheduler.py index 4124f27..6e65238 100755 --- a/distiller/scheduler.py +++ b/distiller/scheduler.py @@ -23,7 +23,7 @@ import logging import torch from .quantization.quantizer import FP_BKP_PREFIX from .policy import PolicyLoss, LossComponent - +from .utils import model_device msglogger = logging.getLogger() @@ -187,7 +187,7 @@ class CompressionScheduler(object): state = {'masks_dict': masks} return state - def load_state_dict(self, state, device): + def load_state_dict(self, state): """Loads the scheduler state. Currently the scheduler state is comprised only of the set of pruning masks. @@ -207,6 +207,7 @@ class CompressionScheduler(object): print("\t\t" + k) exit(1) + device = model_device(self.model) for name, mask in self.zeros_mask_dict.items(): masker = self.zeros_mask_dict[name] masker.mask = loaded_masks[name] -- GitLab