From 05bf755112683f6395c538c99b2c12857d494f3a Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Wed, 6 Feb 2019 15:52:45 +0200 Subject: [PATCH] Filter ranking: added support for L2-norm ranking with AGP pruning schedule --- distiller/pruning/__init__.py | 6 ++++-- distiller/pruning/automated_gradual_pruner.py | 7 +++++++ distiller/pruning/ranked_structures_pruner.py | 7 +++++-- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/distiller/pruning/__init__.py b/distiller/pruning/__init__.py index 2f576e9..a26ecd5 100755 --- a/distiller/pruning/__init__.py +++ b/distiller/pruning/__init__.py @@ -19,14 +19,16 @@ """ from .magnitude_pruner import MagnitudeParameterPruner -from .automated_gradual_pruner import AutomatedGradualPruner, L1RankedStructureParameterPruner_AGP, \ +from .automated_gradual_pruner import AutomatedGradualPruner, \ + L1RankedStructureParameterPruner_AGP, L2RankedStructureParameterPruner_AGP, \ ActivationAPoZRankedFilterPruner_AGP, GradientRankedFilterPruner_AGP, \ RandomRankedFilterPruner_AGP from .level_pruner import SparsityLevelParameterPruner from .sensitivity_pruner import SensitivityPruner from .splicing_pruner import SplicingPruner from .structure_pruner import StructureParameterPruner -from .ranked_structures_pruner import L1RankedStructureParameterPruner, ActivationAPoZRankedFilterPruner, \ +from .ranked_structures_pruner import L1RankedStructureParameterPruner, L2RankedStructureParameterPruner, \ + ActivationAPoZRankedFilterPruner, \ RandomRankedFilterPruner, GradientRankedFilterPruner from .baidu_rnn_pruner import BaiduRNNPruner diff --git a/distiller/pruning/automated_gradual_pruner.py b/distiller/pruning/automated_gradual_pruner.py index e267a68..9afaaef 100755 --- a/distiller/pruning/automated_gradual_pruner.py +++ b/distiller/pruning/automated_gradual_pruner.py @@ -105,6 +105,13 @@ class L1RankedStructureParameterPruner_AGP(StructuredAGP): group_dependency=group_dependency, kwargs=kwargs) +class L2RankedStructureParameterPruner_AGP(StructuredAGP): + def __init__(self, name, initial_sparsity, final_sparsity, group_type, weights, group_dependency=None, kwargs=None): + super().__init__(name, initial_sparsity, final_sparsity) + self.pruner = L2RankedStructureParameterPruner(name, group_type, desired_sparsity=0, weights=weights, + group_dependency=group_dependency, kwargs=kwargs) + + class ActivationAPoZRankedFilterPruner_AGP(StructuredAGP): def __init__(self, name, initial_sparsity, final_sparsity, group_type, weights, group_dependency=None): assert group_type in ['3D', 'Filters'] diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py index 3af23c1..1a23c62 100755 --- a/distiller/pruning/ranked_structures_pruner.py +++ b/distiller/pruning/ranked_structures_pruner.py @@ -83,9 +83,12 @@ l2_magnitude = partial(torch.norm, p=2) class LpRankedStructureParameterPruner(RankedStructureParameterPruner): - """Uses mean L1-norm to rank and prune structures. + """Uses Lp-norm to rank and prune structures. - This class prunes to a prescribed percentage of structured-sparsity (level pruning). + This class prunes to a prescribed percentage of structured-sparsity (level pruning), by + first ranking (sorting) the structures based on their Lp-norm, and then pruning a perctenage + of the lower-ranking structures. + See also: https://en.wikipedia.org/wiki/Lp_space#The_p-norm_in_finite_dimensions """ def __init__(self, name, group_type, desired_sparsity, weights, group_dependency=None, kwargs=None, magnitude_fn=None): -- GitLab