diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py
index 1ca9237a42a5d83616c9566ee1487599d8317dd7..ca1a4a1d257ee0a240d38c3f35a4d0e554cd82b4 100755
--- a/distiller/pruning/ranked_structures_pruner.py
+++ b/distiller/pruning/ranked_structures_pruner.py
@@ -407,13 +407,16 @@ class GradientRankedFilterPruner(RankedStructureParameterPruner):
 
     def rank_and_prune_filters(self, fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map=None):
         assert param.dim() == 4, "This thresholding is only supported for 4D weights"
+        if param.grad is None:
+            msglogger.info("Skipping gradient pruning of %s because it does not have a gradient yet", param_name)
+            return
         num_filters = param.size(0)
         num_filters_to_prune = int(fraction_to_prune * num_filters)
         if num_filters_to_prune == 0:
             msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune)
             return
 
-        # Compute the multiplicatipn of the filters times the filter_gradienrs
+        # Compute the multiplication of the filters times the filter_gradienrs
         view_filters = param.view(param.size(0), -1)
         view_filter_grads = param.grad.view(param.size(0), -1)
         weighted_gradients = view_filter_grads * view_filters
@@ -421,7 +424,7 @@ class GradientRankedFilterPruner(RankedStructureParameterPruner):
 
         # Sort from high to low, and remove the bottom 'num_filters_to_prune' filters
         filters_ordered_by_gradient = np.argsort(-weighted_gradients.detach().cpu().numpy())[:-num_filters_to_prune]
-        mask, binary_map = mask_from_filter_order(filters_ordered_by_gradient, param, num_filters)
+        mask, binary_map = mask_from_filter_order(filters_ordered_by_gradient, param, num_filters, binary_map)
         zeros_mask_dict[param_name].mask = mask
 
         msglogger.info("GradientRankedFilterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",