Commit 90226b1c authored by Neta Zmora

Added BernoulliFilterPruner_AGP

This is AGP (automatic gradual pruning) for a pruner that selects
the filters to prune by sampling a Bernoulli probability distribution.
parent 74a4f7ab
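The diff below wires a new BernoulliFilterPruner_AGP wrapper into Distiller's AGP machinery. For reference, AGP (Zhu & Gupta, "To prune, or not to prune") ramps the sparsity target from an initial to a final level along a cubic curve as training progresses. The sketch below only illustrates that schedule; its function and parameter names are not Distiller's API:

def agp_target_sparsity(step, start_step, end_step, initial_sparsity, final_sparsity):
    """Cubic AGP ramp: return the sparsity target for the current training step."""
    if step <= start_step:
        return initial_sparsity
    if step >= end_step:
        return final_sparsity
    progress = (step - start_step) / (end_step - start_step)
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** 3

The wrapper classes in the diff pass initial_sparsity and final_sparsity to the StructuredAGP base class, which drives the schedule.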
@@ -25,7 +25,8 @@ from .automated_gradual_pruner import AutomatedGradualPruner, \
     ActivationAPoZRankedFilterPruner_AGP, \
     ActivationMeanRankedFilterPruner_AGP, \
     GradientRankedFilterPruner_AGP, \
-    RandomRankedFilterPruner_AGP
+    RandomRankedFilterPruner_AGP, \
+    BernoulliFilterPruner_AGP
 from .level_pruner import SparsityLevelParameterPruner
 from .sensitivity_pruner import SensitivityPruner
 from .splicing_pruner import SplicingPruner
@@ -34,8 +35,10 @@ from .ranked_structures_pruner import L1RankedStructureParameterPruner, \
     L2RankedStructureParameterPruner, \
     ActivationAPoZRankedFilterPruner, \
     ActivationMeanRankedFilterPruner, \
-    GradientRankedFilterPruner, \
-    RandomRankedFilterPruner
+    GradientRankedFilterPruner, \
+    RandomRankedFilterPruner, \
+    RandomLevelStructureParameterPruner, \
+    BernoulliFilterPruner
 from .baidu_rnn_pruner import BaiduRNNPruner
 from .greedy_filter_pruning import greedy_pruner
@@ -141,3 +141,11 @@ class RandomRankedFilterPruner_AGP(StructuredAGP):
         super().__init__(name, initial_sparsity, final_sparsity)
         self.pruner = RandomRankedFilterPruner(name, group_type, desired_sparsity=0,
                                                weights=weights, group_dependency=group_dependency)
+
+
+class BernoulliFilterPruner_AGP(StructuredAGP):
+    def __init__(self, name, initial_sparsity, final_sparsity, group_type, weights, group_dependency=None):
+        assert group_type in ['3D', 'Filters']
+        super().__init__(name, initial_sparsity, final_sparsity)
+        self.pruner = BernoulliFilterPruner(name, group_type, desired_sparsity=0,
+                                            weights=weights, group_dependency=group_dependency)
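As with RandomRankedFilterPruner_AGP above, the new wrapper passes desired_sparsity=0; the AGP schedule supplies the effective sparsity target at each pruning step. Conceptually, a Bernoulli filter pruner keeps each output filter with probability 1 - desired_sparsity and drops the rest. The sketch below illustrates that sampling with made-up names; it is not the Distiller implementation:

import torch

def sample_bernoulli_filter_mask(weight, desired_sparsity):
    # weight is a 4-D conv tensor: [num_filters, in_channels, k_h, k_w]
    num_filters = weight.size(0)
    keep_prob = 1.0 - desired_sparsity
    # Draw an independent keep/drop decision for each filter.
    binary_map = torch.bernoulli(torch.full((num_filters,), keep_prob))
    # Broadcast the per-filter decision over the remaining dimensions.
    mask = binary_map.view(-1, 1, 1, 1).expand_as(weight).contiguous()
    return mask, binary_map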
@@ -365,6 +365,7 @@ def mask_from_filter_order(filters_ordered_by_criterion, param, num_filters, bin
     if binary_map is None:
         binary_map = torch.zeros(num_filters).cuda()
         binary_map[filters_ordered_by_criterion] = 1
+    binary_map = binary_map.detach()
     expanded = binary_map.expand(param.size(1) * param.size(2) * param.size(3), param.size(0)).t().contiguous()
     return expanded.view(param.shape), binary_map
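The added detach() presumably ensures the binary map carries no autograd history regardless of how the caller produced it, so the expanded mask acts as a constant when it multiplies a weight tensor. A generic illustration, not Distiller code:

import torch

weight = torch.randn(4, 3, 3, 3, requires_grad=True)
binary_map = torch.tensor([1., 0., 1., 1.])               # keep filters 0, 2, 3
mask = binary_map.detach().view(-1, 1, 1, 1).expand_as(weight)
(weight * mask).sum().backward()                          # grad w.r.t. weight equals the mask;
                                                          # the dropped filter gets zero gradient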
@@ -509,6 +510,7 @@ class BernoulliFilterPruner(RankedStructureParameterPruner):
         # Compensate for dropping filters
         pruning_factor = binary_map.sum() / num_filters
         mask.div_(pruning_factor)
+        zeros_mask_dict[param_name].mask = mask
         msglogger.debug("BernoulliFilterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
                         param_name,
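The new line stores the computed mask in zeros_mask_dict for this parameter. The surrounding division compensates for the dropped filters in the style of inverted dropout: surviving filters are scaled by the inverse of the kept fraction so the expected magnitude of the layer's output is roughly preserved. A self-contained illustration of that scaling, repeating the hypothetical sampling from the earlier sketch (not Distiller code):

import torch

weight = torch.randn(64, 3, 3, 3)
keep_prob = 0.5
binary_map = torch.bernoulli(torch.full((weight.size(0),), keep_prob))
mask = binary_map.view(-1, 1, 1, 1).expand_as(weight).contiguous()
pruning_factor = binary_map.sum() / binary_map.numel()    # fraction of filters actually kept, ~0.5 here
mask = mask / pruning_factor                              # kept entries become ~2.0, dropped entries stay 0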