diff --git a/distiller/pruning/__init__.py b/distiller/pruning/__init__.py index 2f576e9dbab717b153cfa836e4507a98fdad4f76..a26ecd547bbd958075f6023402877f56d723e865 100755 --- a/distiller/pruning/__init__.py +++ b/distiller/pruning/__init__.py @@ -19,14 +19,16 @@ """ from .magnitude_pruner import MagnitudeParameterPruner -from .automated_gradual_pruner import AutomatedGradualPruner, L1RankedStructureParameterPruner_AGP, \ +from .automated_gradual_pruner import AutomatedGradualPruner, \ + L1RankedStructureParameterPruner_AGP, L2RankedStructureParameterPruner_AGP, \ ActivationAPoZRankedFilterPruner_AGP, GradientRankedFilterPruner_AGP, \ RandomRankedFilterPruner_AGP from .level_pruner import SparsityLevelParameterPruner from .sensitivity_pruner import SensitivityPruner from .splicing_pruner import SplicingPruner from .structure_pruner import StructureParameterPruner -from .ranked_structures_pruner import L1RankedStructureParameterPruner, ActivationAPoZRankedFilterPruner, \ +from .ranked_structures_pruner import L1RankedStructureParameterPruner, L2RankedStructureParameterPruner, \ + ActivationAPoZRankedFilterPruner, \ RandomRankedFilterPruner, GradientRankedFilterPruner from .baidu_rnn_pruner import BaiduRNNPruner diff --git a/distiller/pruning/automated_gradual_pruner.py b/distiller/pruning/automated_gradual_pruner.py index e267a686df218f67065f04b2e93a5e6407f5554d..9afaaef93cd83f6703d9c8609a525524957ad6b3 100755 --- a/distiller/pruning/automated_gradual_pruner.py +++ b/distiller/pruning/automated_gradual_pruner.py @@ -105,6 +105,13 @@ class L1RankedStructureParameterPruner_AGP(StructuredAGP): group_dependency=group_dependency, kwargs=kwargs) +class L2RankedStructureParameterPruner_AGP(StructuredAGP): + def __init__(self, name, initial_sparsity, final_sparsity, group_type, weights, group_dependency=None, kwargs=None): + super().__init__(name, initial_sparsity, final_sparsity) + self.pruner = L2RankedStructureParameterPruner(name, group_type, desired_sparsity=0, weights=weights, + group_dependency=group_dependency, kwargs=kwargs) + + class ActivationAPoZRankedFilterPruner_AGP(StructuredAGP): def __init__(self, name, initial_sparsity, final_sparsity, group_type, weights, group_dependency=None): assert group_type in ['3D', 'Filters'] diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py index 3af23c19deb7e99d706d4d9c8aa5ad23f4b7b74b..1a23c620dc9ea432fe0ef4658e877cb4bb670456 100755 --- a/distiller/pruning/ranked_structures_pruner.py +++ b/distiller/pruning/ranked_structures_pruner.py @@ -83,9 +83,12 @@ l2_magnitude = partial(torch.norm, p=2) class LpRankedStructureParameterPruner(RankedStructureParameterPruner): - """Uses mean L1-norm to rank and prune structures. + """Uses Lp-norm to rank and prune structures. - This class prunes to a prescribed percentage of structured-sparsity (level pruning). + This class prunes to a prescribed percentage of structured-sparsity (level pruning), by + first ranking (sorting) the structures based on their Lp-norm, and then pruning a perctenage + of the lower-ranking structures. + See also: https://en.wikipedia.org/wiki/Lp_space#The_p-norm_in_finite_dimensions """ def __init__(self, name, group_type, desired_sparsity, weights, group_dependency=None, kwargs=None, magnitude_fn=None):