From b6b6f8174ac3307fcd1c4339cb3926abe921ec64 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Sun, 17 Mar 2019 13:49:15 +0200
Subject: [PATCH] Add two stochastic pruners
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

BernoulliFilterPruner – assigns a Bernoulli probability distribution to
each of the filters.
RandomLevelStructureParameterPruner – assigns a Uniform probability
distribution to the pruning level used by an L1-norm structured pruner.
---
 distiller/pruning/ranked_structures_pruner.py | 65 +++++++++++++++++--
 1 file changed, 60 insertions(+), 5 deletions(-)

diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py
index 8be38b5..2ee79cb 100755
--- a/distiller/pruning/ranked_structures_pruner.py
+++ b/distiller/pruning/ranked_structures_pruner.py
@@ -18,6 +18,7 @@ from functools import partial
 import numpy as np
 import logging
 import torch
+from random import uniform
 import distiller
 from .pruner import _ParameterPruner
 msglogger = logging.getLogger()
@@ -174,7 +175,7 @@ class LpRankedStructureParameterPruner(RankedStructureParameterPruner):
     @staticmethod
     def rank_and_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict,
                                model=None, binary_map=None, magnitude_fn=l1_magnitude):
-        assert param.dim() == 4, "This thresholding is only supported for 4D weights"
+        assert param.dim() == 4, "This pruning is only supported for 4D weights"
 
         threshold = None
         if binary_map is None:
@@ -217,7 +218,7 @@ class LpRankedStructureParameterPruner(RankedStructureParameterPruner):
         away from each other in memory, and this means poor use of caches and
         system memory.
         """
-        assert param.dim() == 2, "This thresholding is only supported for 2D weights"
+        assert param.dim() == 2, "This pruning is only supported for 2D weights"
         ROWS_DIM = 0
         THRESHOLD_DIM = 'Cols'
         rows_mags = magnitude_fn(param, dim=ROWS_DIM)
@@ -338,6 +339,23 @@ class L2RankedStructureParameterPruner(LpRankedStructureParameterPruner):
                          group_dependency, kwargs, magnitude_fn=l2_magnitude)
 
 
+class RandomLevelStructureParameterPruner(L1RankedStructureParameterPruner):
+    """Uses mean L1-norm to rank and prune structures, with a random pruning regimen.
+
+    This class sets the pruning level to a random value in the range sparsity_range,
+    and chooses which structures to remove using L1-norm ranking.
+    The idea is similar to DropFilter, but instead of randomly choosing filters,
+    we randomly choose a sparsity level and then prune according to magnitude.
+ """ + def __init__(self, name, group_type, sparsity_range, weights, + group_dependency=None, kwargs=None): + self.sparsity_range = sparsity_range + super().__init__(name, group_type, 0, weights, group_dependency, kwargs) + + def fraction_to_prune(self, param_name): + return uniform(self.sparsity_range[0], self.sparsity_range[1]) + + def mask_from_filter_order(filters_ordered_by_criterion, param, num_filters, binary_map): if binary_map is None: binary_map = torch.zeros(num_filters).cuda() @@ -365,7 +383,7 @@ class ActivationRankedFilterPruner(RankedStructureParameterPruner): return binary_map def rank_and_prune_filters(self, fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map=None): - assert param.dim() == 4, "This thresholding is only supported for 4D weights" + assert param.dim() == 4, "This pruning is only supported for 4D weights" # Use the parameter name to locate the module that has the activation sparsity statistics fq_name = param_name.replace(".conv", ".relu")[:-len(".weight")] @@ -438,7 +456,7 @@ class RandomRankedFilterPruner(RankedStructureParameterPruner): return binary_map def rank_and_prune_filters(self, fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map=None): - assert param.dim() == 4, "This thresholding is only supported for 4D weights" + assert param.dim() == 4, "This pruning is only supported for 4D weights" num_filters = param.size(0) num_filters_to_prune = int(fraction_to_prune * num_filters) @@ -457,6 +475,43 @@ class RandomRankedFilterPruner(RankedStructureParameterPruner): return binary_map +class BernoulliFilterPruner(RankedStructureParameterPruner): + """A Bernoulli probability for dropping each filter. + + This is can be used for random filter-dropping algorithms (e.g. DropFilter) + """ + def __init__(self, name, group_type, desired_sparsity, weights, group_dependency=None): + super().__init__(name, group_type, desired_sparsity, weights, group_dependency) + + def prune_group(self, fraction_to_prune, param, param_name, zeros_mask_dict, model=None, binary_map=None): + if fraction_to_prune == 0: + return + binary_map = self.rank_and_prune_filters(fraction_to_prune, param, param_name, + zeros_mask_dict, model, binary_map) + return binary_map + + def rank_and_prune_filters(self, fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map=None): + assert param.dim() == 4, "This pruner is only supported for 4D weights" + num_filters = param.size(0) + num_filters_to_prune = int(fraction_to_prune * num_filters) + + keep_prob = 1 - fraction_to_prune + if binary_map is None: + binary_map = torch.bernoulli(torch.as_tensor([keep_prob] * num_filters)) + mask, _ = mask_from_filter_order(None, param, num_filters, binary_map) + # mask = mask.detach() + mask = mask.to(param.device) + # Compensate for dropping filters + pruning_factor = binary_map.sum() / num_filters + mask.div_(pruning_factor) + zeros_mask_dict[param_name].mask = mask + msglogger.debug("BernoulliFilterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", + param_name, + distiller.sparsity_3D(zeros_mask_dict[param_name].mask), + fraction_to_prune, num_filters_to_prune, num_filters) + return binary_map + + class GradientRankedFilterPruner(RankedStructureParameterPruner): """ """ @@ -471,7 +526,7 @@ class GradientRankedFilterPruner(RankedStructureParameterPruner): return binary_map def rank_and_prune_filters(self, fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map=None): - assert param.dim() == 4, "This thresholding is only supported for 4D 
weights" + assert param.dim() == 4, "This pruning is only supported for 4D weights" if param.grad is None: msglogger.info("Skipping gradient pruning of %s because it does not have a gradient yet", param_name) return -- GitLab