From 05bf755112683f6395c538c99b2c12857d494f3a Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Wed, 6 Feb 2019 15:52:45 +0200
Subject: [PATCH] Filter ranking: added support for L2-norm ranking with AGP
 pruning schedule

---
 distiller/pruning/__init__.py                 | 6 ++++--
 distiller/pruning/automated_gradual_pruner.py | 7 +++++++
 distiller/pruning/ranked_structures_pruner.py | 7 +++++--
 3 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/distiller/pruning/__init__.py b/distiller/pruning/__init__.py
index 2f576e9..a26ecd5 100755
--- a/distiller/pruning/__init__.py
+++ b/distiller/pruning/__init__.py
@@ -19,14 +19,16 @@
 """
 
 from .magnitude_pruner import MagnitudeParameterPruner
-from .automated_gradual_pruner import AutomatedGradualPruner, L1RankedStructureParameterPruner_AGP, \
+from .automated_gradual_pruner import AutomatedGradualPruner, \
+                                      L1RankedStructureParameterPruner_AGP, L2RankedStructureParameterPruner_AGP, \
                                       ActivationAPoZRankedFilterPruner_AGP, GradientRankedFilterPruner_AGP, \
                                       RandomRankedFilterPruner_AGP
 from .level_pruner import SparsityLevelParameterPruner
 from .sensitivity_pruner import SensitivityPruner
 from .splicing_pruner import SplicingPruner
 from .structure_pruner import StructureParameterPruner
-from .ranked_structures_pruner import L1RankedStructureParameterPruner, ActivationAPoZRankedFilterPruner, \
+from .ranked_structures_pruner import L1RankedStructureParameterPruner, L2RankedStructureParameterPruner, \
+                                      ActivationAPoZRankedFilterPruner, \
                                       RandomRankedFilterPruner, GradientRankedFilterPruner
 from .baidu_rnn_pruner import BaiduRNNPruner
 
diff --git a/distiller/pruning/automated_gradual_pruner.py b/distiller/pruning/automated_gradual_pruner.py
index e267a68..9afaaef 100755
--- a/distiller/pruning/automated_gradual_pruner.py
+++ b/distiller/pruning/automated_gradual_pruner.py
@@ -105,6 +105,13 @@ class L1RankedStructureParameterPruner_AGP(StructuredAGP):
                                                        group_dependency=group_dependency, kwargs=kwargs)
 
 
+class L2RankedStructureParameterPruner_AGP(StructuredAGP):
+    def __init__(self, name, initial_sparsity, final_sparsity, group_type, weights, group_dependency=None, kwargs=None):
+        super().__init__(name, initial_sparsity, final_sparsity)
+        self.pruner = L2RankedStructureParameterPruner(name, group_type, desired_sparsity=0, weights=weights,
+                                                       group_dependency=group_dependency, kwargs=kwargs)
+
+
 class ActivationAPoZRankedFilterPruner_AGP(StructuredAGP):
     def __init__(self, name, initial_sparsity, final_sparsity, group_type, weights, group_dependency=None):
         assert group_type in ['3D', 'Filters']
diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py
index 3af23c1..1a23c62 100755
--- a/distiller/pruning/ranked_structures_pruner.py
+++ b/distiller/pruning/ranked_structures_pruner.py
@@ -83,9 +83,12 @@ l2_magnitude = partial(torch.norm, p=2)
 
 
 class LpRankedStructureParameterPruner(RankedStructureParameterPruner):
-    """Uses mean L1-norm to rank and prune structures.
+    """Uses Lp-norm to rank and prune structures.
 
-    This class prunes to a prescribed percentage of structured-sparsity (level pruning).
+    This class prunes to a prescribed percentage of structured-sparsity (level pruning), by
+    first ranking (sorting) the structures based on their Lp-norm, and then pruning a perctenage
+    of the lower-ranking structures.
+    See also: https://en.wikipedia.org/wiki/Lp_space#The_p-norm_in_finite_dimensions
     """
     def __init__(self, name, group_type, desired_sparsity, weights,
                  group_dependency=None, kwargs=None, magnitude_fn=None):
-- 
GitLab