From 4cc0e7d6e2749b8d0c8014836f9ea99cf40a02df Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Wed, 16 Jan 2019 12:52:12 +0200
Subject: [PATCH] Fix for CPU support

---
 apputils/checkpoint.py | 2 +-
 distiller/scheduler.py | 5 +++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/apputils/checkpoint.py b/apputils/checkpoint.py
index c5c23fd..5f84883 100755
--- a/apputils/checkpoint.py
+++ b/apputils/checkpoint.py
@@ -94,7 +94,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None):
 
         if 'compression_sched' in checkpoint:
             compression_scheduler = distiller.CompressionScheduler(model)
-            compression_scheduler.load_state_dict(checkpoint['compression_sched'], distiller.model_device(model))
+            compression_scheduler.load_state_dict(checkpoint['compression_sched'])
             msglogger.info("Loaded compression schedule from checkpoint (epoch %d)",
                            checkpoint['epoch'])
         else:
diff --git a/distiller/scheduler.py b/distiller/scheduler.py
index 4124f27..6e65238 100755
--- a/distiller/scheduler.py
+++ b/distiller/scheduler.py
@@ -23,7 +23,7 @@ import logging
 import torch
 from .quantization.quantizer import FP_BKP_PREFIX
 from .policy import PolicyLoss, LossComponent
-
+from .utils import model_device
 msglogger = logging.getLogger()
 
 
@@ -187,7 +187,7 @@ class CompressionScheduler(object):
         state = {'masks_dict': masks}
         return state
 
-    def load_state_dict(self, state, device):
+    def load_state_dict(self, state):
         """Loads the scheduler state.
 
         Currently the scheduler state is comprised only of the set of pruning masks.
@@ -207,6 +207,7 @@ class CompressionScheduler(object):
                 print("\t\t" + k)
             exit(1)
 
+        device = model_device(self.model)
         for name, mask in self.zeros_mask_dict.items():
             masker = self.zeros_mask_dict[name]
             masker.mask = loaded_masks[name]
-- 
GitLab