From 2179ec50d2b586d33997c6a1f48fc10ebbf497a8 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Tue, 5 Feb 2019 16:34:09 +0200
Subject: [PATCH] Filter ranking: add support for ranking by L2 magnitude

---
 distiller/pruning/ranked_structures_pruner.py | 95 ++++++++++++-------
 distiller/thresholding.py                     | 10 +-
 2 files changed, 70 insertions(+), 35 deletions(-)

diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py
index ca1a4a1..3af23c1 100755
--- a/distiller/pruning/ranked_structures_pruner.py
+++ b/distiller/pruning/ranked_structures_pruner.py
@@ -78,17 +78,25 @@ class RankedStructureParameterPruner(_ParameterPruner):
         raise NotImplementedError
 
 
-class L1RankedStructureParameterPruner(RankedStructureParameterPruner):
+l1_magnitude = partial(torch.norm, p=1)
+l2_magnitude = partial(torch.norm, p=2)
+
+
+class LpRankedStructureParameterPruner(RankedStructureParameterPruner):
     """Uses mean L1-norm to rank and prune structures.
 
     This class prunes to a prescribed percentage of structured-sparsity (level pruning).
     """
-    def __init__(self, name, group_type, desired_sparsity, weights, group_dependency=None, kwargs=None):
+    def __init__(self, name, group_type, desired_sparsity, weights,
+                 group_dependency=None, kwargs=None, magnitude_fn=None):
         super().__init__(name, group_type, desired_sparsity, weights, group_dependency)
         if group_type not in ['3D', 'Filters', 'Channels', 'Rows', 'Blocks']:
             raise ValueError("Structure {} was requested but "
                              "currently ranking of this shape is not supported".
                              format(group_type))
+        assert magnitude_fn is not None
+        self.magnitude_fn = magnitude_fn
+
         if group_type == 'Blocks':
             try:
                 self.block_shape = kwargs['block_shape']
@@ -101,18 +109,19 @@ class L1RankedStructureParameterPruner(RankedStructureParameterPruner):
         if self.group_type in ['3D', 'Filters']:
             group_pruning_fn = self.rank_and_prune_filters
         elif self.group_type == 'Channels':
-            group_pruning_fn = self.rank_and_prune_channels
+            group_pruning_fn = partial(self.rank_and_prune_channels)
         elif self.group_type == 'Rows':
             group_pruning_fn = self.rank_and_prune_rows
         elif self.group_type == 'Blocks':
             group_pruning_fn = partial(self.rank_and_prune_blocks, block_shape=self.block_shape)
 
-        binary_map = group_pruning_fn(fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map)
+        binary_map = group_pruning_fn(fraction_to_prune, param, param_name,
+                                      zeros_mask_dict, model, binary_map, self.magnitude_fn)
         return binary_map
 
     @staticmethod
     def rank_and_prune_channels(fraction_to_prune, param, param_name=None,
-                                zeros_mask_dict=None, model=None, binary_map=None):
+                                zeros_mask_dict=None, model=None, binary_map=None, magnitude_fn=l1_magnitude):
         def rank_channels(fraction_to_prune, param):
             num_filters = param.size(0)
             num_channels = param.size(1)
@@ -122,9 +131,9 @@ class L1RankedStructureParameterPruner(RankedStructureParameterPruner):
             # tensor, is now a row in the 2D tensor.
             view_2d = param.view(-1, kernel_size)
             # Next, compute the sums of each kernel
-            kernel_sums = view_2d.abs().sum(dim=1)
+            kernel_mags = magnitude_fn(view_2d, dim=1)
             # Now group by channels
-            k_sums_mat = kernel_sums.view(num_filters, num_channels).t()
+            k_sums_mat = kernel_mags.view(num_filters, num_channels).t()
             channel_mags = k_sums_mat.mean(dim=1)
             k = int(fraction_to_prune * channel_mags.size(0))
             if k == 0:
@@ -160,14 +169,14 @@ class L1RankedStructureParameterPruner(RankedStructureParameterPruner):
 
     @staticmethod
     def rank_and_prune_filters(fraction_to_prune, param, param_name,
-                               zeros_mask_dict, model=None, binary_map=None):
+                               zeros_mask_dict, model=None, binary_map=None, magnitude_fn=l1_magnitude):
         assert param.dim() == 4, "This thresholding is only supported for 4D weights"
 
         threshold = None
         if binary_map is None:
             # First we rank the filters
             view_filters = param.view(param.size(0), -1)
-            filter_mags = view_filters.data.abs().mean(dim=1)
+            filter_mags = magnitude_fn(view_filters, dim=1)
             topk_filters = int(fraction_to_prune * filter_mags.size(0))
             if topk_filters == 0:
                 msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune)
@@ -178,7 +187,8 @@ class L1RankedStructureParameterPruner(RankedStructureParameterPruner):
                            param_name,
                            topk_filters, filter_mags.size(0))
         # Then we threshold
-        mask, binary_map = distiller.group_threshold_mask(param, 'Filters', threshold, 'Mean_Abs', binary_map)
+        threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2'
+        mask, binary_map = distiller.group_threshold_mask(param, 'Filters', threshold, threshold_type, binary_map)
         if zeros_mask_dict is not None:
             zeros_mask_dict[param_name].mask = mask
         msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f",
@@ -189,7 +199,7 @@ class L1RankedStructureParameterPruner(RankedStructureParameterPruner):
 
     @staticmethod
     def rank_and_prune_rows(fraction_to_prune, param, param_name,
-                            zeros_mask_dict, model=None, binary_map=None):
+                            zeros_mask_dict, model=None, binary_map=None, magnitude_fn=l1_magnitude):
         """Prune the rows of a matrix, based on ranked L1-norms of the matrix rows.
 
         PyTorch stores the weights matrices in a transposed format.  I.e. before performing GEMM, a matrix is
@@ -203,21 +213,23 @@ class L1RankedStructureParameterPruner(RankedStructureParameterPruner):
         assert param.dim() == 2, "This thresholding is only supported for 2D weights"
         ROWS_DIM = 0
         THRESHOLD_DIM = 'Cols'
-        rows_mags = param.abs().mean(dim=ROWS_DIM)
+        rows_mags = magnitude_fn(param, dim=ROWS_DIM)
         num_rows_to_prune = int(fraction_to_prune * rows_mags.size(0))
         if num_rows_to_prune == 0:
             msglogger.info("Too few filters - can't prune %.1f%% rows", 100*fraction_to_prune)
             return
         bottomk_rows, _ = torch.topk(rows_mags, num_rows_to_prune, largest=False, sorted=True)
         threshold = bottomk_rows[-1]
-        zeros_mask_dict[param_name].mask = distiller.group_threshold_mask(param, THRESHOLD_DIM, threshold, 'Mean_Abs')
+        threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2'
+        zeros_mask_dict[param_name].mask = distiller.group_threshold_mask(param, THRESHOLD_DIM,
+                                                                          threshold, threshold_type)
         msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name,
                        distiller.sparsity(zeros_mask_dict[param_name].mask),
                        fraction_to_prune, num_rows_to_prune, rows_mags.size(0))
 
     @staticmethod
-    def rank_and_prune_blocks(fraction_to_prune, param, param_name=None,
-                              zeros_mask_dict=None, model=None, binary_map=None, block_shape=None):
+    def rank_and_prune_blocks(fraction_to_prune, param, param_name=None, zeros_mask_dict=None,
+                              model=None, binary_map=None, block_shape=None, magnitude_fn=l1_magnitude):
         """Block-wise pruning for 4D tensors.
 
         The block shape is specified using a tuple: [block_repetitions, block_depth, block_height, block_width].
@@ -251,26 +263,20 @@ class L1RankedStructureParameterPruner(RankedStructureParameterPruner):
         kernel_size = param.size(2) * param.size(3)
 
         if block_depth > 1:
-            view_dims = (
-                num_filters*num_channels//(block_repetitions*block_depth),
-                block_repetitions*block_depth,
-                kernel_size,
-                )
+            view_dims = (num_filters*num_channels//(block_repetitions*block_depth),
+                         block_repetitions*block_depth,
+                         kernel_size,)
         else:
-            view_dims = (
-                num_filters // block_repetitions,
-                block_repetitions,
-                -1,
-                )
+            view_dims = (num_filters // block_repetitions,
+                         block_repetitions,
+                         -1,)
 
         def rank_blocks(fraction_to_prune, param):
             # Create a view where each block is a column
             view1 = param.view(*view_dims)
             # Next, compute the sums of each column (block)
-            block_sums = view1.abs().sum(dim=1)
-
-            # Now group by channels
-            block_mags = block_sums.view(-1)  # flatten
+            block_mags = magnitude_fn(view1, dim=1)
+            block_mags = block_mags.view(-1)  # flatten
             k = int(fraction_to_prune * block_mags.size(0))
             if k == 0:
                 msglogger.info("Too few blocks (%d)- can't prune %.1f%% blocks",
@@ -302,6 +308,28 @@ class L1RankedStructureParameterPruner(RankedStructureParameterPruner):
         return binary_map
 
 
+class L1RankedStructureParameterPruner(LpRankedStructureParameterPruner):
+    """Uses mean L1-norm to rank and prune structures.
+
+    This class prunes to a prescribed percentage of structured-sparsity (level pruning).
+    """
+    def __init__(self, name, group_type, desired_sparsity, weights,
+                 group_dependency=None, kwargs=None):
+        super().__init__(name, group_type, desired_sparsity, weights,
+                         group_dependency, kwargs, magnitude_fn=l1_magnitude)
+
+
+class L2RankedStructureParameterPruner(LpRankedStructureParameterPruner):
+    """Uses mean L2-norm to rank and prune structures.
+
+    This class prunes to a prescribed percentage of structured-sparsity (level pruning).
+    """
+    def __init__(self, name, group_type, desired_sparsity, weights,
+                 group_dependency=None, kwargs=None):
+        super().__init__(name, group_type, desired_sparsity, weights,
+                         group_dependency, kwargs, magnitude_fn=l2_magnitude)
+
+
 def mask_from_filter_order(filters_ordered_by_criterion, param, num_filters, binary_map):
     if binary_map is None:
         binary_map = torch.zeros(num_filters).cuda()
@@ -324,7 +352,8 @@ class ActivationAPoZRankedFilterPruner(RankedStructureParameterPruner):
     def prune_group(self, fraction_to_prune, param, param_name, zeros_mask_dict, model=None, binary_map=None):
         if fraction_to_prune == 0:
             return
-        binary_map = self.rank_and_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map)
+        binary_map = self.rank_and_prune_filters(fraction_to_prune, param, param_name,
+                                                 zeros_mask_dict, model, binary_map)
         return binary_map
 
     def rank_and_prune_filters(self, fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map=None):
@@ -370,7 +399,8 @@ class RandomRankedFilterPruner(RankedStructureParameterPruner):
     def prune_group(self, fraction_to_prune, param, param_name, zeros_mask_dict, model=None, binary_map=None):
         if fraction_to_prune == 0:
             return
-        binary_map = self.rank_and_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map)
+        binary_map = self.rank_and_prune_filters(fraction_to_prune, param, param_name,
+                                                 zeros_mask_dict, model, binary_map)
         return binary_map
 
     def rank_and_prune_filters(self, fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map=None):
@@ -402,7 +432,8 @@ class GradientRankedFilterPruner(RankedStructureParameterPruner):
     def prune_group(self, fraction_to_prune, param, param_name, zeros_mask_dict, model=None, binary_map=None):
         if fraction_to_prune == 0:
             return
-        binary_map = self.rank_and_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map)
+        binary_map = self.rank_and_prune_filters(fraction_to_prune, param, param_name,
+                                                 zeros_mask_dict, model, binary_map)
         return binary_map
 
     def rank_and_prune_filters(self, fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map=None):
diff --git a/distiller/thresholding.py b/distiller/thresholding.py
index be03a27..4f0b845 100755
--- a/distiller/thresholding.py
+++ b/distiller/thresholding.py
@@ -151,7 +151,7 @@ def group_threshold_mask(param, group_type, threshold, threshold_criteria, binar
         if binary_map is None:
             binary_map = group_threshold_binary_map(param, group_type, threshold, threshold_criteria)
         a = binary_map.expand(param.size(1) * param.size(2) * param.size(3), param.size(0)).t()
-        return a.view(param.size(0), param.size(1), param.size(2), param.size(3)), binary_map
+        return a.view(*param.shape), binary_map
 
     elif group_type == '4D':
         assert param.dim() == 4, "This thresholding is only supported for 4D weights"
@@ -181,10 +181,14 @@ def group_threshold_mask(param, group_type, threshold, threshold_criteria, binar
 def threshold_policy(weights, thresholds, threshold_criteria, dim=1):
     """
     """
-    if threshold_criteria == 'Mean_Abs':
-        return weights.data.abs().mean(dim=dim).gt(thresholds).type(weights.type())
+    if threshold_criteria in ['Mean_Abs', 'Mean_L1']:
+        return weights.data.norm(p=1, dim=dim).div(weights.size(dim)).gt(thresholds).type(weights.type())
+    if threshold_criteria == 'Mean_L2':
+        return weights.data.norm(p=2, dim=dim).div(weights.size(dim)).gt(thresholds).type(weights.type())
     elif threshold_criteria == 'L1':
         return weights.data.norm(p=1, dim=dim).gt(thresholds).type(weights.type())
+    elif threshold_criteria == 'L2':
+        return weights.data.norm(p=2, dim=dim).gt(thresholds).type(weights.type())
     elif threshold_criteria == 'Max':
         maxv, _ = weights.data.abs().max(dim=dim)
         return maxv.gt(thresholds).type(weights.type())
-- 
GitLab