From d1ef193014de33d97c2216e1246f67ade2d2c989 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Sun, 13 Jan 2019 18:13:13 +0200
Subject: [PATCH] CPU support: correct the device used for pruning masks

When masks are loaded from a checkpoint file, they should use the
same device as the model.
---
 apputils/checkpoint.py | 2 +-
 distiller/scheduler.py | 4 +++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/apputils/checkpoint.py b/apputils/checkpoint.py
index 5f84883..c5c23fd 100755
--- a/apputils/checkpoint.py
+++ b/apputils/checkpoint.py
@@ -94,7 +94,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None):
 
         if 'compression_sched' in checkpoint:
             compression_scheduler = distiller.CompressionScheduler(model)
-            compression_scheduler.load_state_dict(checkpoint['compression_sched'])
+            compression_scheduler.load_state_dict(checkpoint['compression_sched'], distiller.model_device(model))
             msglogger.info("Loaded compression schedule from checkpoint (epoch %d)",
                            checkpoint['epoch'])
         else:
diff --git a/distiller/scheduler.py b/distiller/scheduler.py
index 8b26eeb..4124f27 100755
--- a/distiller/scheduler.py
+++ b/distiller/scheduler.py
@@ -187,7 +187,7 @@ class CompressionScheduler(object):
         state = {'masks_dict': masks}
         return state
 
-    def load_state_dict(self, state):
+    def load_state_dict(self, state, device):
         """Loads the scheduler state.
 
         Currently the scheduler state is comprised only of the set of pruning masks.
@@ -210,6 +210,8 @@ class CompressionScheduler(object):
         for name, mask in self.zeros_mask_dict.items():
             masker = self.zeros_mask_dict[name]
             masker.mask = loaded_masks[name]
+            if masker.mask is not None:
+                masker.mask = masker.mask.to(device)
 
     @staticmethod
     def verify_policy_loss(policy_loss):
-- 
GitLab