Commit 947143a1 authored by Neta Zmora

Model thinning bug fix

This bug is triggered, for example, by the following command:
python3 compress_classifier.py --arch resnet20_cifar  ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp_4.yaml --resume-from=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --reset-optimizer --vs=0

The root cause is a (non-functional) code refactoring made in commit
992291cf.

The problem is in an `if` statement that handles the Optimizer reshaping.
The bad commit combined two `if` statements into one, and moved the
combined statement to a place where control flows through it when
it shouldn't.
The fix reverts to the two original (and separate) `if` statements
(see the illustrative sketch after the diff below).
parent a0436c26
@@ -521,6 +521,10 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_from_file
                 grad_selection_view = param.grad.resize_(*directive[2])
                 if grad_selection_view.size(dim) != len_indices:
                     param.grad = torch.index_select(grad_selection_view, dim, indices)
+                    # update optimizer
+                    if optimizer_thinning(optimizer, param, dim, indices, directive[3]):
+                        msglogger.debug("Updated [4D] velocity buffer for {} (dim={},size={},shape={})".
+                                        format(param_name, dim, len_indices, directive[3]))
 
             param.data = param.view(*directive[3])
             if param.grad is not None:
@@ -537,13 +541,9 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_from_file
             # not exist, and therefore won't need to be re-dimensioned.
             if param.grad is not None and param.grad.size(dim) != len_indices:
                 param.grad = torch.index_select(param.grad, dim, indices.to(param.device))
-            # update optimizer
-            if optimizer_thinning(optimizer, param, dim, indices,
-                                  new_shape=directive[3] if len(directive)==4 else None):
-                msglogger.debug("Updated velocity buffer %s" % param_name)
-            else:
-                msglogger.debug('Failed to update the optimizer by thinning directive')
+                # update optimizer
+                if optimizer_thinning(optimizer, param, dim, indices):
+                    msglogger.debug("Updated velocity buffer %s" % param_name)
         if not loaded_from_file:
             # If the masks are loaded from a checkpoint file, then we don't need to change
...
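
For context, the "velocity buffer" that `optimizer_thinning()` updates is, in the common SGD-with-momentum case, the momentum buffer kept in the optimizer's state. When filters are removed from a parameter with `torch.index_select`, that buffer must be shrunk along the same dimension, and only when the parameter/gradient was actually re-dimensioned. The sketch below is illustrative only; the helper `thin_momentum_buffer` and its logic are assumptions, not Distiller's actual `optimizer_thinning()` implementation.

import torch

# Illustrative only -- NOT Distiller's optimizer_thinning(); it just shows the
# idea of "thinning" an SGD momentum (velocity) buffer so that it keeps
# matching a parameter that lost some of its filters.
def thin_momentum_buffer(optimizer, param, dim, indices):
    """Shrink param's momentum buffer along `dim`, keeping only `indices`.
    Returns True if a buffer was found and updated."""
    state = optimizer.state.get(param, {})
    buf = state.get('momentum_buffer')
    if buf is None or buf.size(dim) == len(indices):
        return False
    state['momentum_buffer'] = torch.index_select(buf, dim, indices)
    return True

weight = torch.nn.Parameter(torch.randn(16, 8, 3, 3))   # conv layer with 16 filters
opt = torch.optim.SGD([weight], lr=0.1, momentum=0.9)

# One optimizer step so that a momentum buffer is created for `weight`.
weight.grad = torch.randn_like(weight)
opt.step()

# "Thin" the layer: keep only 10 of the 16 filters (dim 0).
keep = torch.arange(10)
weight.data = torch.index_select(weight.data, 0, keep)
weight.grad = torch.index_select(weight.grad, 0, keep)

# The gradient was actually re-dimensioned here, which is the condition under
# which the fixed code calls optimizer_thinning(); the buffer must follow suit.
assert thin_momentum_buffer(opt, weight, dim=0, indices=keep)
assert opt.state[weight]['momentum_buffer'].shape == weight.shape

In the real code the 4D case also passes the new shape (`directive[3]`) to `optimizer_thinning()`, as shown in the first hunk above; the sketch only covers plain index selection along one dimension.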