diff --git a/distiller/thinning.py b/distiller/thinning.py
index edea055665405255e648325478fffbe306cd607a..43608b563db8bbf2f914ca178d95da560357a3b8 100755
--- a/distiller/thinning.py
+++ b/distiller/thinning.py
@@ -521,6 +521,10 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr
                     grad_selection_view = param.grad.resize_(*directive[2])
                     if grad_selection_view.size(dim) != len_indices:
                         param.grad = torch.index_select(grad_selection_view, dim, indices)
+                        # update optimizer
+                        if optimizer_thinning(optimizer, param, dim, indices, directive[3]):
+                            msglogger.debug("Updated [4D] velocity buffer for {} (dim={},size={},shape={})".
+                                            format(param_name, dim, len_indices, directive[3]))
 
                 param.data = param.view(*directive[3])
                 if param.grad is not None:
@@ -537,13 +541,9 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr
                 # not exist, and therefore won't need to be re-dimensioned.
                 if param.grad is not None and param.grad.size(dim) != len_indices:
                     param.grad = torch.index_select(param.grad, dim, indices.to(param.device))
-
-            # update optimizer
-            if optimizer_thinning(optimizer, param, dim, indices,
-                                  new_shape=directive[3] if len(directive)==4 else None):
-                msglogger.debug("Updated velocity buffer %s" % param_name)
-            else:
-                msglogger.debug('Failed to update the optimizer by thinning directive')
+                # update optimizer
+                if optimizer_thinning(optimizer, param, dim, indices):
+                    msglogger.debug("Updated velocity buffer %s" % param_name)
 
         if not loaded_from_file:
            # If the masks are loaded from a checkpoint file, then we don't need to change
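
Note: optimizer_thinning() (called in both hunks) is expected to shrink the SGD momentum ("velocity") buffer so it stays aligned with the thinned parameter, returning True when a buffer was updated. Below is a minimal sketch of such a helper, assuming a torch.optim.SGD optimizer that keeps a per-parameter 'momentum_buffer' in its state; only the call signatures appear in the patch, so the body is illustrative, not code taken from this repository.

import torch

def optimizer_thinning(optimizer, param, dim, indices, new_shape=None):
    # Sketch (assumption): keep the SGD velocity ("momentum_buffer") in sync with a
    # parameter whose slices were just removed along `dim` via index_select.
    if optimizer is None or not isinstance(optimizer, torch.optim.SGD):
        return False
    for group in optimizer.param_groups:
        if group.get('momentum', 0) == 0:
            continue  # no velocity buffer is kept without momentum
        for p in group['params']:
            if p is not param:
                continue
            state = optimizer.state[p]
            if 'momentum_buffer' in state:
                # Select the surviving slices of the velocity buffer ...
                state['momentum_buffer'] = torch.index_select(state['momentum_buffer'], dim, indices)
                if new_shape is not None:
                    # ... and, for the 4D case, give it the parameter's final shape.
                    state['momentum_buffer'] = state['momentum_buffer'].resize_(*new_shape)
                return True
    return False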