diff --git a/distiller/pruning/__init__.py b/distiller/pruning/__init__.py index 960177c5c2438365ccae757ce30f6c077150d3cf..abc2525d6b4ffd1baecb4e3d45a951eff03be25a 100755 --- a/distiller/pruning/__init__.py +++ b/distiller/pruning/__init__.py @@ -25,7 +25,8 @@ from .automated_gradual_pruner import AutomatedGradualPruner, \ ActivationAPoZRankedFilterPruner_AGP, \ ActivationMeanRankedFilterPruner_AGP, \ GradientRankedFilterPruner_AGP, \ - RandomRankedFilterPruner_AGP + RandomRankedFilterPruner_AGP, \ + BernoulliFilterPruner_AGP from .level_pruner import SparsityLevelParameterPruner from .sensitivity_pruner import SensitivityPruner from .splicing_pruner import SplicingPruner @@ -34,8 +35,10 @@ from .ranked_structures_pruner import L1RankedStructureParameterPruner, \ L2RankedStructureParameterPruner, \ ActivationAPoZRankedFilterPruner, \ ActivationMeanRankedFilterPruner, \ - GradientRankedFilterPruner, \ - RandomRankedFilterPruner + GradientRankedFilterPruner, \ + RandomRankedFilterPruner, \ + RandomLevelStructureParameterPruner, \ + BernoulliFilterPruner from .baidu_rnn_pruner import BaiduRNNPruner from .greedy_filter_pruning import greedy_pruner diff --git a/distiller/pruning/automated_gradual_pruner.py b/distiller/pruning/automated_gradual_pruner.py index b2c275d023ada2ded66db4f2836e004937e686ed..51a9abb5ef2b0f48a289d91768cff43bfac17143 100755 --- a/distiller/pruning/automated_gradual_pruner.py +++ b/distiller/pruning/automated_gradual_pruner.py @@ -141,3 +141,11 @@ class RandomRankedFilterPruner_AGP(StructuredAGP): super().__init__(name, initial_sparsity, final_sparsity) self.pruner = RandomRankedFilterPruner(name, group_type, desired_sparsity=0, weights=weights, group_dependency=group_dependency) + + +class BernoulliFilterPruner_AGP(StructuredAGP): + def __init__(self, name, initial_sparsity, final_sparsity, group_type, weights, group_dependency=None): + assert group_type in ['3D', 'Filters'] + super().__init__(name, initial_sparsity, final_sparsity) + self.pruner = BernoulliFilterPruner(name, group_type, desired_sparsity=0, + weights=weights, group_dependency=group_dependency) diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py index a38c87c23449fa4e821d5b45990a75032a0063fe..f813b211beb626fa195941ea5e8277dfeea61ade 100755 --- a/distiller/pruning/ranked_structures_pruner.py +++ b/distiller/pruning/ranked_structures_pruner.py @@ -365,6 +365,7 @@ def mask_from_filter_order(filters_ordered_by_criterion, param, num_filters, bin if binary_map is None: binary_map = torch.zeros(num_filters).cuda() binary_map[filters_ordered_by_criterion] = 1 + binary_map = binary_map.detach() expanded = binary_map.expand(param.size(1) * param.size(2) * param.size(3), param.size(0)).t().contiguous() return expanded.view(param.shape), binary_map @@ -509,6 +510,7 @@ class BernoulliFilterPruner(RankedStructureParameterPruner): # Compensate for dropping filters pruning_factor = binary_map.sum() / num_filters mask.div_(pruning_factor) + zeros_mask_dict[param_name].mask = mask msglogger.debug("BernoulliFilterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name,