From e564a05f47a9e15e8575615d1ba92358b9184b67 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Wed, 16 Jan 2019 01:15:48 +0200
Subject: [PATCH] CPU support: fix thinning directive tensor migration to
 CPU/GPU

---
 distiller/thinning.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/distiller/thinning.py b/distiller/thinning.py
index 5cdff7f..58ceb1c 100755
--- a/distiller/thinning.py
+++ b/distiller/thinning.py
@@ -474,7 +474,7 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr
     This will remove filters and channels, as well as handle batch-normalization parameter
     adjustment, and thinning of weight tensors.
     """
-
+    device = distiller.utils.model_device(model)
     layers = {mod_name: m for mod_name, m in model.named_modules()}
     for layer_name, directives in recipe.modules.items():
         for attr, val in directives.items():
@@ -500,7 +500,7 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr
         assert param is not None
         for directive in param_directives:
             dim = directive[0]
-            indices = directive[1]
+            indices = directive[1].to(device)
             len_indices = indices.nelement()
             if len(directive) == 4:  # TODO: this code is hard to follow
                 msglogger.debug("{}-{}-{}: SHAPE = {}".format(param_name, param.shape, id(param), list(directive[2])))
-- 
GitLab