From 6d7288a82046bd64a0356623e67851e3fabc25f4 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Wed, 9 Jan 2019 12:23:59 +0200
Subject: [PATCH] Fix for GradientRankedFilterPruner

A parameter was missing from one of the function calls.
---
 distiller/pruning/ranked_structures_pruner.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py
index 1ca9237..ca1a4a1 100755
--- a/distiller/pruning/ranked_structures_pruner.py
+++ b/distiller/pruning/ranked_structures_pruner.py
@@ -407,13 +407,16 @@ class GradientRankedFilterPruner(RankedStructureParameterPruner):
 
     def rank_and_prune_filters(self, fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map=None):
         assert param.dim() == 4, "This thresholding is only supported for 4D weights"
+        if param.grad is None:
+            msglogger.info("Skipping gradient pruning of %s because it does not have a gradient yet", param_name)
+            return
         num_filters = param.size(0)
         num_filters_to_prune = int(fraction_to_prune * num_filters)
         if num_filters_to_prune == 0:
             msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune)
             return
 
-        # Compute the multiplicatipn of the filters times the filter_gradienrs
+        # Compute the multiplication of the filters times the filter_gradienrs
         view_filters = param.view(param.size(0), -1)
         view_filter_grads = param.grad.view(param.size(0), -1)
         weighted_gradients = view_filter_grads * view_filters
@@ -421,7 +424,7 @@ class GradientRankedFilterPruner(RankedStructureParameterPruner):
 
         # Sort from high to low, and remove the bottom 'num_filters_to_prune' filters
         filters_ordered_by_gradient = np.argsort(-weighted_gradients.detach().cpu().numpy())[:-num_filters_to_prune]
-        mask, binary_map = mask_from_filter_order(filters_ordered_by_gradient, param, num_filters)
+        mask, binary_map = mask_from_filter_order(filters_ordered_by_gradient, param, num_filters, binary_map)
         zeros_mask_dict[param_name].mask = mask
 
         msglogger.info("GradientRankedFilterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
-- 
GitLab