From d1ef193014de33d97c2216e1246f67ade2d2c989 Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Sun, 13 Jan 2019 18:13:13 +0200 Subject: [PATCH] CPU support: correct the device used for pruning masks When masks are loaded from a checkpoint file, they should use the same device as the model. --- apputils/checkpoint.py | 2 +- distiller/scheduler.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/apputils/checkpoint.py b/apputils/checkpoint.py index 5f84883..c5c23fd 100755 --- a/apputils/checkpoint.py +++ b/apputils/checkpoint.py @@ -94,7 +94,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None): if 'compression_sched' in checkpoint: compression_scheduler = distiller.CompressionScheduler(model) - compression_scheduler.load_state_dict(checkpoint['compression_sched']) + compression_scheduler.load_state_dict(checkpoint['compression_sched'], distiller.model_device(model)) msglogger.info("Loaded compression schedule from checkpoint (epoch %d)", checkpoint['epoch']) else: diff --git a/distiller/scheduler.py b/distiller/scheduler.py index 8b26eeb..4124f27 100755 --- a/distiller/scheduler.py +++ b/distiller/scheduler.py @@ -187,7 +187,7 @@ class CompressionScheduler(object): state = {'masks_dict': masks} return state - def load_state_dict(self, state): + def load_state_dict(self, state, device): """Loads the scheduler state. Currently the scheduler state is comprised only of the set of pruning masks. @@ -210,6 +210,8 @@ class CompressionScheduler(object): for name, mask in self.zeros_mask_dict.items(): masker = self.zeros_mask_dict[name] masker.mask = loaded_masks[name] + if masker.mask is not None: + masker.mask = masker.mask.to(device) @staticmethod def verify_policy_loss(policy_loss): -- GitLab