From ba1ee25b90b6fa6d85763d04c34f087f2016f986 Mon Sep 17 00:00:00 2001
From: Bar <elhararb@gmail.com>
Date: Sun, 17 Mar 2019 14:05:26 +0200
Subject: [PATCH] Fix logging message for structured pruning (#194)

Modify LpRankedStructureParameterPruner to log as
the correct class name for both L1 and L2 pruners.
---
 distiller/pruning/ranked_structures_pruner.py | 21 ++++++++++++-------
 1 file changed, 13 insertions(+), 8 deletions(-)

diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py
index 2ee79cb..a38c87c 100755
--- a/distiller/pruning/ranked_structures_pruner.py
+++ b/distiller/pruning/ranked_structures_pruner.py
@@ -165,9 +165,11 @@ class LpRankedStructureParameterPruner(RankedStructureParameterPruner):
             threshold = bottomk_channels[-1]
             binary_map = channel_mags.gt(threshold).type(param.data.type())
 
+        threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2'
         if zeros_mask_dict is not None:
             zeros_mask_dict[param_name].mask = binary_map_to_mask(binary_map, param)
-            msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name,
+            msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
+                           threshold_type, param_name,
                            distiller.sparsity_ch(zeros_mask_dict[param_name].mask),
                            fraction_to_prune, binary_map.sum().item(), param.size(1))
         return binary_map
@@ -178,6 +180,7 @@ class LpRankedStructureParameterPruner(RankedStructureParameterPruner):
         assert param.dim() == 4, "This pruning is only supported for 4D weights"
 
         threshold = None
+        threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2'
         if binary_map is None:
             # First we rank the filters
             view_filters = param.view(param.size(0), -1)
@@ -188,16 +191,15 @@ class LpRankedStructureParameterPruner(RankedStructureParameterPruner):
                 return
             bottomk, _ = torch.topk(filter_mags, topk_filters, largest=False, sorted=True)
             threshold = bottomk[-1]
-            msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=(%d/%d)",
-                           param_name,
+            msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=(%d/%d)",
+                           threshold_type, param_name,
                            topk_filters, filter_mags.size(0))
         # Then we threshold
-        threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2'
         mask, binary_map = distiller.group_threshold_mask(param, 'Filters', threshold, threshold_type, binary_map)
         if zeros_mask_dict is not None:
             zeros_mask_dict[param_name].mask = mask
-        msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f",
-                       param_name,
+        msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f",
+                       threshold_type, param_name,
                        distiller.sparsity(mask),
                        fraction_to_prune)
 
@@ -231,7 +233,8 @@ class LpRankedStructureParameterPruner(RankedStructureParameterPruner):
         threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2'
         zeros_mask_dict[param_name].mask, binary_map = distiller.group_threshold_mask(param, THRESHOLD_DIM,
                                                                                       threshold, threshold_type)
-        msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name,
+        msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
+                       threshold_type, param_name,
                        distiller.sparsity(zeros_mask_dict[param_name].mask),
                        fraction_to_prune, num_rows_to_prune, rows_mags.size(0))
         return binary_map
@@ -309,9 +312,11 @@ class LpRankedStructureParameterPruner(RankedStructureParameterPruner):
             threshold = bottomk_blocks[-1]
             binary_map = block_mags.gt(threshold).type(param.data.type())
 
+        threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2'
         if zeros_mask_dict is not None:
             zeros_mask_dict[param_name].mask = binary_map_to_mask(binary_map, param)
-            msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name,
+            msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
+                           threshold_type, param_name,
                            distiller.sparsity_blocks(zeros_mask_dict[param_name].mask, block_shape=block_shape),
                            fraction_to_prune, binary_map.sum().item(), num_super_blocks)
         return binary_map
-- 
GitLab