Skip to content
Snippets Groups Projects
Commit e564a05f authored by Neta Zmora's avatar Neta Zmora
Browse files

CPU support: fix thinning directive tensor migration to CPU/GPU

parent 81cb77d2
No related branches found
No related tags found
No related merge requests found
...@@ -474,7 +474,7 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr ...@@ -474,7 +474,7 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr
This will remove filters and channels, as well as handle batch-normalization parameter This will remove filters and channels, as well as handle batch-normalization parameter
adjustment, and thinning of weight tensors. adjustment, and thinning of weight tensors.
""" """
device = distiller.utils.model_device(model)
layers = {mod_name: m for mod_name, m in model.named_modules()} layers = {mod_name: m for mod_name, m in model.named_modules()}
for layer_name, directives in recipe.modules.items(): for layer_name, directives in recipe.modules.items():
for attr, val in directives.items(): for attr, val in directives.items():
...@@ -500,7 +500,7 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr ...@@ -500,7 +500,7 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr
assert param is not None assert param is not None
for directive in param_directives: for directive in param_directives:
dim = directive[0] dim = directive[0]
indices = directive[1] indices = directive[1].to(device)
len_indices = indices.nelement() len_indices = indices.nelement()
if len(directive) == 4: # TODO: this code is hard to follow if len(directive) == 4: # TODO: this code is hard to follow
msglogger.debug("{}-{}-{}: SHAPE = {}".format(param_name, param.shape, id(param), list(directive[2]))) msglogger.debug("{}-{}-{}: SHAPE = {}".format(param_name, param.shape, id(param), list(directive[2])))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment