From 6d7288a82046bd64a0356623e67851e3fabc25f4 Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Wed, 9 Jan 2019 12:23:59 +0200 Subject: [PATCH] Fix for GradientRankedFilterPruner A parameter was missing from one of the function calls. --- distiller/pruning/ranked_structures_pruner.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py index 1ca9237..ca1a4a1 100755 --- a/distiller/pruning/ranked_structures_pruner.py +++ b/distiller/pruning/ranked_structures_pruner.py @@ -407,13 +407,16 @@ class GradientRankedFilterPruner(RankedStructureParameterPruner): def rank_and_prune_filters(self, fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map=None): assert param.dim() == 4, "This thresholding is only supported for 4D weights" + if param.grad is None: + msglogger.info("Skipping gradient pruning of %s because it does not have a gradient yet", param_name) + return num_filters = param.size(0) num_filters_to_prune = int(fraction_to_prune * num_filters) if num_filters_to_prune == 0: msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune) return - # Compute the multiplicatipn of the filters times the filter_gradienrs + # Compute the multiplication of the filters times the filter_gradienrs view_filters = param.view(param.size(0), -1) view_filter_grads = param.grad.view(param.size(0), -1) weighted_gradients = view_filter_grads * view_filters @@ -421,7 +424,7 @@ class GradientRankedFilterPruner(RankedStructureParameterPruner): # Sort from high to low, and remove the bottom 'num_filters_to_prune' filters filters_ordered_by_gradient = np.argsort(-weighted_gradients.detach().cpu().numpy())[:-num_filters_to_prune] - mask, binary_map = mask_from_filter_order(filters_ordered_by_gradient, param, num_filters) + mask, binary_map = mask_from_filter_order(filters_ordered_by_gradient, param, num_filters, binary_map) zeros_mask_dict[param_name].mask = mask msglogger.info("GradientRankedFilterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", -- GitLab