From 947143a159d78c6e809cee921ca1cecac2af6d31 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Thu, 4 Jul 2019 15:42:50 +0300
Subject: [PATCH] Model thinning bug fix

This bug is triggered, for example, by the following command:

python3 compress_classifier.py --arch resnet20_cifar ../../../data.cifar10 -p=50 --lr=0.3 --epochs=180 --compress=../agp-pruning/resnet20_filters.schedule_agp_4.yaml --resume-from=../ssl/checkpoints/checkpoint_trained_dense.pth.tar --reset-optimizer --vs=0

The root cause is a (supposedly non-functional) code refactoring made in
commit 992291cfe4a2d917eef6f926f58dc10a68f82105. The problem is in an
`if` statement that handles the optimizer reshaping: the bad commit
combined two `if` statements into one and moved the combined statement
to a place where control flows through it when it shouldn't. The fix
reverts to the two original, separate `if` statements.
---
 distiller/thinning.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/distiller/thinning.py b/distiller/thinning.py
index edea055..43608b5 100755
--- a/distiller/thinning.py
+++ b/distiller/thinning.py
@@ -521,6 +521,10 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr
                         grad_selection_view = param.grad.resize_(*directive[2])
                         if grad_selection_view.size(dim) != len_indices:
                             param.grad = torch.index_select(grad_selection_view, dim, indices)
+                            # update optimizer
+                            if optimizer_thinning(optimizer, param, dim, indices, directive[3]):
+                                msglogger.debug("Updated [4D] velocity buffer for {} (dim={},size={},shape={})".
+                                                format(param_name, dim, len_indices, directive[3]))
 
                 param.data = param.view(*directive[3])
                 if param.grad is not None:
@@ -537,13 +541,9 @@ def execute_thinning_recipe(model, zeros_mask_dict, recipe, optimizer, loaded_fr
                 # not exist, and therefore won't need to be re-dimensioned.
                 if param.grad is not None and param.grad.size(dim) != len_indices:
                     param.grad = torch.index_select(param.grad, dim, indices.to(param.device))
-
-            # update optimizer
-            if optimizer_thinning(optimizer, param, dim, indices,
-                                  new_shape=directive[3] if len(directive)==4 else None):
-                msglogger.debug("Updated velocity buffer %s" % param_name)
-            else:
-                msglogger.debug('Failed to update the optimizer by thinning directive')
+                # update optimizer
+                if optimizer_thinning(optimizer, param, dim, indices):
+                    msglogger.debug("Updated velocity buffer %s" % param_name)
 
             if not loaded_from_file:
                 # If the masks are loaded from a checkpoint file, then we don't need to change
-- 
GitLab
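
For illustration, here is a minimal, self-contained sketch of the
control-flow problem the commit message describes. Every name below is
an illustrative placeholder, not the actual distiller code:

    # Minimal sketch of the bug: a combined statement hoisted below two
    # branches runs on every pass, even when neither branch prepared
    # the state it depends on. All names here are hypothetical.

    def update_optimizer(new_shape=None):
        # Stand-in for the real optimizer_thinning() call.
        print("optimizer updated, new_shape =", new_shape)

    def buggy(directive):
        if len(directive) == 4:
            pass  # 4D path: grad is re-dimensioned only under a nested `if`
        else:
            pass  # simple index_select path
        # BUG: one combined statement below both branches, so it also
        # runs when the 4D path skipped the gradient re-dimensioning:
        update_optimizer(directive[3] if len(directive) == 4 else None)

    def fixed(directive, grad_redimensioned):
        if len(directive) == 4:
            if grad_redimensioned:              # first original `if`
                update_optimizer(directive[3])
        else:
            update_optimizer()                  # second original `if`

The fix mirrors `fixed()` above: each branch updates the optimizer
itself, under its own conditions, instead of sharing one hoisted call.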