From 90226b1cd38d2210e6fb7da988e28ce721b65804 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Thu, 21 Mar 2019 12:44:53 +0200
Subject: [PATCH] Added BernoulliFilterPruner_AGP

This is AGP (automatic gradual pruning) for a pruner which selects
filters-to-prune by sampling a Bernoulli probability distribution.
---
 distiller/pruning/__init__.py                 | 9 ++++++---
 distiller/pruning/automated_gradual_pruner.py | 8 ++++++++
 distiller/pruning/ranked_structures_pruner.py | 2 ++
 3 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/distiller/pruning/__init__.py b/distiller/pruning/__init__.py
index 960177c..abc2525 100755
--- a/distiller/pruning/__init__.py
+++ b/distiller/pruning/__init__.py
@@ -25,7 +25,8 @@ from .automated_gradual_pruner import AutomatedGradualPruner, \
                                       ActivationAPoZRankedFilterPruner_AGP, \
                                       ActivationMeanRankedFilterPruner_AGP, \
                                       GradientRankedFilterPruner_AGP, \
-                                      RandomRankedFilterPruner_AGP
+                                      RandomRankedFilterPruner_AGP, \
+                                      BernoulliFilterPruner_AGP
 from .level_pruner import SparsityLevelParameterPruner
 from .sensitivity_pruner import SensitivityPruner
 from .splicing_pruner import SplicingPruner
@@ -34,8 +35,10 @@ from .ranked_structures_pruner import L1RankedStructureParameterPruner, \
                                       L2RankedStructureParameterPruner, \
                                       ActivationAPoZRankedFilterPruner, \
                                       ActivationMeanRankedFilterPruner, \
-                                      GradientRankedFilterPruner, \
-                                      RandomRankedFilterPruner
+                                      GradientRankedFilterPruner, \
+                                      RandomRankedFilterPruner, \
+                                      RandomLevelStructureParameterPruner, \
+                                      BernoulliFilterPruner
 from .baidu_rnn_pruner import BaiduRNNPruner
 from .greedy_filter_pruning import greedy_pruner
 
diff --git a/distiller/pruning/automated_gradual_pruner.py b/distiller/pruning/automated_gradual_pruner.py
index b2c275d..51a9abb 100755
--- a/distiller/pruning/automated_gradual_pruner.py
+++ b/distiller/pruning/automated_gradual_pruner.py
@@ -141,3 +141,11 @@ class RandomRankedFilterPruner_AGP(StructuredAGP):
         super().__init__(name, initial_sparsity, final_sparsity)
         self.pruner = RandomRankedFilterPruner(name, group_type, desired_sparsity=0,
                                                weights=weights, group_dependency=group_dependency)
+
+
+class BernoulliFilterPruner_AGP(StructuredAGP):
+    def __init__(self, name, initial_sparsity, final_sparsity, group_type, weights, group_dependency=None):
+        assert group_type in ['3D', 'Filters']
+        super().__init__(name, initial_sparsity, final_sparsity)
+        self.pruner = BernoulliFilterPruner(name, group_type, desired_sparsity=0,
+                                            weights=weights, group_dependency=group_dependency)
diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py
index a38c87c..f813b21 100755
--- a/distiller/pruning/ranked_structures_pruner.py
+++ b/distiller/pruning/ranked_structures_pruner.py
@@ -365,6 +365,7 @@ def mask_from_filter_order(filters_ordered_by_criterion, param, num_filters, bin
     if binary_map is None:
         binary_map = torch.zeros(num_filters).cuda()
         binary_map[filters_ordered_by_criterion] = 1
+    binary_map = binary_map.detach()
     expanded = binary_map.expand(param.size(1) * param.size(2) * param.size(3),
                                  param.size(0)).t().contiguous()
     return expanded.view(param.shape), binary_map
@@ -509,6 +510,7 @@ class BernoulliFilterPruner(RankedStructureParameterPruner):
         # Compensate for dropping filters
         pruning_factor = binary_map.sum() / num_filters
         mask.div_(pruning_factor)
+        zeros_mask_dict[param_name].mask = mask
         msglogger.debug("BernoulliFilterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
                         param_name,
--
GitLab
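
The snippet below is a minimal, self-contained sketch of the sampling idea behind BernoulliFilterPruner, not the library's implementation. It assumes a 4-D convolution weight of shape (num_filters, channels, k, k); the helper name bernoulli_filter_mask and its arguments are illustrative. Each filter is kept with probability 1 - desired_sparsity by drawing from a Bernoulli distribution, the per-filter decisions are expanded into a full weight mask, and the surviving weights are rescaled to compensate for the dropped filters, mirroring the mask.div_(pruning_factor) step in the patched code.

import torch

def bernoulli_filter_mask(param, desired_sparsity):
    """Sketch: build a filter-level mask for a 4-D conv weight by
    Bernoulli-sampling which filters to keep (illustrative helper)."""
    num_filters = param.size(0)
    keep_prob = 1.0 - desired_sparsity
    # One independent keep/drop draw per filter.
    binary_map = torch.bernoulli(torch.full((num_filters,), keep_prob))
    # Expand each filter's 0/1 decision over all of that filter's weights.
    mask = binary_map.view(-1, 1, 1, 1).expand_as(param).contiguous()
    # Rescale surviving weights to compensate for dropped filters,
    # analogous to mask.div_(pruning_factor) in the diff above;
    # guard against the (unlikely) case where every filter was dropped.
    pruning_factor = binary_map.sum() / num_filters
    if pruning_factor > 0:
        mask = mask / pruning_factor
    return mask, binary_map

For example, mask, _ = bernoulli_filter_mask(conv.weight, desired_sparsity=0.5) followed by conv.weight.data.mul_(mask) zeroes the sampled filters while rescaling the rest.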