From d75ba51e74b47d71886cd30abd5abfee5750ecc5 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Sun, 17 Mar 2019 13:39:16 +0200
Subject: [PATCH] Fix row-pruning

A recent change requires us to return the binary_map from the
ranking operation, and this was missing for the row-pruning case.
---
 distiller/pruning/ranked_structures_pruner.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py
index b5f990a..8be38b5 100755
--- a/distiller/pruning/ranked_structures_pruner.py
+++ b/distiller/pruning/ranked_structures_pruner.py
@@ -228,11 +228,12 @@ class LpRankedStructureParameterPruner(RankedStructureParameterPruner):
         bottomk_rows, _ = torch.topk(rows_mags, num_rows_to_prune, largest=False, sorted=True)
         threshold = bottomk_rows[-1]
         threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2'
-        zeros_mask_dict[param_name].mask = distiller.group_threshold_mask(param, THRESHOLD_DIM,
-                                                                          threshold, threshold_type)
+        zeros_mask_dict[param_name].mask, binary_map = distiller.group_threshold_mask(param, THRESHOLD_DIM,
+                                                                                      threshold, threshold_type)
         msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name,
                        distiller.sparsity(zeros_mask_dict[param_name].mask),
                        fraction_to_prune, num_rows_to_prune, rows_mags.size(0))
+        return binary_map
 
     @staticmethod
     def rank_and_prune_blocks(fraction_to_prune, param, param_name=None, zeros_mask_dict=None,
-- 
GitLab