diff --git a/apputils/checkpoint.py b/apputils/checkpoint.py
index 5f848831bf0fb80709d3b06ddd315bd83ced64f2..c5c23fd6d42ea49209fa1dd74683a40bc055a267 100755
--- a/apputils/checkpoint.py
+++ b/apputils/checkpoint.py
@@ -94,7 +94,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None):
 
         if 'compression_sched' in checkpoint:
             compression_scheduler = distiller.CompressionScheduler(model)
-            compression_scheduler.load_state_dict(checkpoint['compression_sched'])
+            compression_scheduler.load_state_dict(checkpoint['compression_sched'], distiller.model_device(model))
             msglogger.info("Loaded compression schedule from checkpoint (epoch %d)",
                            checkpoint['epoch'])
         else:
diff --git a/distiller/scheduler.py b/distiller/scheduler.py
index 8b26eeb44c73d9769b1a726cab272784a6d0add0..4124f271ff99c75b5b01f34a2ce93419471528de 100755
--- a/distiller/scheduler.py
+++ b/distiller/scheduler.py
@@ -187,7 +187,7 @@ class CompressionScheduler(object):
         state = {'masks_dict': masks}
         return state
 
-    def load_state_dict(self, state):
+    def load_state_dict(self, state, device):
         """Loads the scheduler state.
 
         Currently the scheduler state is comprised only of the set of pruning masks.
@@ -210,6 +210,8 @@ class CompressionScheduler(object):
         for name, mask in self.zeros_mask_dict.items():
             masker = self.zeros_mask_dict[name]
             masker.mask = loaded_masks[name]
+            if masker.mask is not None:
+                masker.mask = masker.mask.to(device)
 
     @staticmethod
     def verify_policy_loss(policy_loss):