diff --git a/distiller/thinning.py b/distiller/thinning.py
index 5cdff7f2ff14491073dfa3429d4978a3d793f11d..58ceb1c20b31372550488521b50c19a3cb0bca1a 100755
--- a/distiller/thinning.py
+++ b/distiller/thinning.py
@@ -474,7 +474,7 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr
     This will remove filters and channels, as well as handle batch-normalization parameter
     adjustment, and thinning of weight tensors.
     """
-
+    device = distiller.utils.model_device(model)
     layers = {mod_name: m for mod_name, m in model.named_modules()}
     for layer_name, directives in recipe.modules.items():
         for attr, val in directives.items():
@@ -500,7 +500,7 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr
         assert param is not None
         for directive in param_directives:
             dim = directive[0]
-            indices = directive[1]
+            indices = directive[1].to(device)
             len_indices = indices.nelement()
             if len(directive) == 4:  # TODO: this code is hard to follow
                 msglogger.debug("{}-{}-{}: SHAPE = {}".format(param_name, param.shape, id(param), list(directive[2])))
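
For context, a minimal sketch of the device mismatch this patch guards against. It assumes `distiller.utils.model_device` simply reports the device that the model's parameters live on; the `model_device` helper, toy model, and index tensor below are illustrative stand-ins, not Distiller's actual recipe data.

```python
import torch
import torch.nn as nn

def model_device(model):
    # Stand-in for distiller.utils.model_device (assumption): the device
    # that the model's parameters live on.
    return next(model.parameters()).device

model = nn.Conv2d(3, 8, kernel_size=3)
if torch.cuda.is_available():
    model = model.cuda()
device = model_device(model)

# A thinning recipe (possibly loaded from a CPU checkpoint) carries index
# tensors that select the filters/channels to keep.
indices = torch.tensor([0, 2, 5])   # CPU LongTensor
param = model.weight                # may be a CUDA tensor

# Indexing a CUDA parameter with CPU indices raises a device-mismatch
# RuntimeError; moving the indices first works on both CPU and GPU.
thinned = torch.index_select(param, 0, indices.to(device))
print(thinned.shape)  # torch.Size([3, 3, 3, 3])
```

Moving the indices per directive (rather than at recipe-creation time) keeps recipes loadable from CPU checkpoints while still allowing them to be applied to a GPU-resident model.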