From 2179ec50d2b586d33997c6a1f48fc10ebbf497a8 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Tue, 5 Feb 2019 16:34:09 +0200
Subject: [PATCH] Filter ranking: add support for ranking by L2 magnitude

---
 distiller/pruning/ranked_structures_pruner.py | 95 ++++++++++++-------
 distiller/thresholding.py                     | 10 +-
 2 files changed, 70 insertions(+), 35 deletions(-)

diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py
index ca1a4a1..3af23c1 100755
--- a/distiller/pruning/ranked_structures_pruner.py
+++ b/distiller/pruning/ranked_structures_pruner.py
@@ -78,17 +78,25 @@ class RankedStructureParameterPruner(_ParameterPruner):
         raise NotImplementedError
 
 
-class L1RankedStructureParameterPruner(RankedStructureParameterPruner):
+l1_magnitude = partial(torch.norm, p=1)
+l2_magnitude = partial(torch.norm, p=2)
+
+
+class LpRankedStructureParameterPruner(RankedStructureParameterPruner):
     """Uses mean L1-norm to rank and prune structures.
 
     This class prunes to a prescribed percentage of structured-sparsity (level pruning).
     """
-    def __init__(self, name, group_type, desired_sparsity, weights, group_dependency=None, kwargs=None):
+    def __init__(self, name, group_type, desired_sparsity, weights,
+                 group_dependency=None, kwargs=None, magnitude_fn=None):
         super().__init__(name, group_type, desired_sparsity, weights, group_dependency)
         if group_type not in ['3D', 'Filters', 'Channels', 'Rows', 'Blocks']:
             raise ValueError("Structure {} was requested but "
                              "currently ranking of this shape is not supported".
                              format(group_type))
+        assert magnitude_fn is not None
+        self.magnitude_fn = magnitude_fn
+
         if group_type == 'Blocks':
             try:
                 self.block_shape = kwargs['block_shape']
@@ -101,18 +109,19 @@ class L1RankedStructureParameterPruner(RankedStructureParameterPruner):
         if self.group_type in ['3D', 'Filters']:
             group_pruning_fn = self.rank_and_prune_filters
         elif self.group_type == 'Channels':
-            group_pruning_fn = self.rank_and_prune_channels
+            group_pruning_fn = partial(self.rank_and_prune_channels)
         elif self.group_type == 'Rows':
             group_pruning_fn = self.rank_and_prune_rows
         elif self.group_type == 'Blocks':
             group_pruning_fn = partial(self.rank_and_prune_blocks, block_shape=self.block_shape)
 
-        binary_map = group_pruning_fn(fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map)
+        binary_map = group_pruning_fn(fraction_to_prune, param, param_name,
+                                      zeros_mask_dict, model, binary_map, self.magnitude_fn)
         return binary_map
 
     @staticmethod
     def rank_and_prune_channels(fraction_to_prune, param, param_name=None,
-                                zeros_mask_dict=None, model=None, binary_map=None):
+                                zeros_mask_dict=None, model=None, binary_map=None, magnitude_fn=l1_magnitude):
         def rank_channels(fraction_to_prune, param):
             num_filters = param.size(0)
             num_channels = param.size(1)
@@ -122,9 +131,9 @@ class L1RankedStructureParameterPruner(RankedStructureParameterPruner):
             # tensor, is now a row in the 2D tensor.
             view_2d = param.view(-1, kernel_size)
             # Next, compute the sums of each kernel
-            kernel_sums = view_2d.abs().sum(dim=1)
+            kernel_mags = magnitude_fn(view_2d, dim=1)
             # Now group by channels
-            k_sums_mat = kernel_sums.view(num_filters, num_channels).t()
+            k_sums_mat = kernel_mags.view(num_filters, num_channels).t()
             channel_mags = k_sums_mat.mean(dim=1)
             k = int(fraction_to_prune * channel_mags.size(0))
             if k == 0:
@@ -160,14 +169,14 @@ class L1RankedStructureParameterPruner(RankedStructureParameterPruner):
 
     @staticmethod
     def rank_and_prune_filters(fraction_to_prune, param, param_name,
-                               zeros_mask_dict, model=None, binary_map=None):
+                               zeros_mask_dict, model=None, binary_map=None, magnitude_fn=l1_magnitude):
         assert param.dim() == 4, "This thresholding is only supported for 4D weights"
 
         threshold = None
         if binary_map is None:
             # First we rank the filters
             view_filters = param.view(param.size(0), -1)
-            filter_mags = view_filters.data.abs().mean(dim=1)
+            filter_mags = magnitude_fn(view_filters, dim=1)
             topk_filters = int(fraction_to_prune * filter_mags.size(0))
             if topk_filters == 0:
                 msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune)
@@ -178,7 +187,8 @@ class L1RankedStructureParameterPruner(RankedStructureParameterPruner):
                            param_name, topk_filters, filter_mags.size(0))
 
         # Then we threshold
-        mask, binary_map = distiller.group_threshold_mask(param, 'Filters', threshold, 'Mean_Abs', binary_map)
+        threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2'
+        mask, binary_map = distiller.group_threshold_mask(param, 'Filters', threshold, threshold_type, binary_map)
         if zeros_mask_dict is not None:
             zeros_mask_dict[param_name].mask = mask
             msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f",
@@ -189,7 +199,7 @@ class L1RankedStructureParameterPruner(RankedStructureParameterPruner):
 
     @staticmethod
     def rank_and_prune_rows(fraction_to_prune, param, param_name,
-                            zeros_mask_dict, model=None, binary_map=None):
+                            zeros_mask_dict, model=None, binary_map=None, magnitude_fn=l1_magnitude):
         """Prune the rows of a matrix, based on ranked L1-norms of the matrix rows.
 
         PyTorch stores the weights matrices in a transposed format.  I.e. before performing GEMM, a matrix is
@@ -203,21 +213,23 @@ class L1RankedStructureParameterPruner(RankedStructureParameterPruner):
         assert param.dim() == 2, "This thresholding is only supported for 2D weights"
         ROWS_DIM = 0
         THRESHOLD_DIM = 'Cols'
-        rows_mags = param.abs().mean(dim=ROWS_DIM)
+        rows_mags = magnitude_fn(param, dim=ROWS_DIM)
         num_rows_to_prune = int(fraction_to_prune * rows_mags.size(0))
         if num_rows_to_prune == 0:
             msglogger.info("Too few filters - can't prune %.1f%% rows", 100*fraction_to_prune)
             return
         bottomk_rows, _ = torch.topk(rows_mags, num_rows_to_prune, largest=False, sorted=True)
         threshold = bottomk_rows[-1]
-        zeros_mask_dict[param_name].mask = distiller.group_threshold_mask(param, THRESHOLD_DIM, threshold, 'Mean_Abs')
+        threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2'
+        zeros_mask_dict[param_name].mask = distiller.group_threshold_mask(param, THRESHOLD_DIM,
+                                                                          threshold, threshold_type)
         msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
                        param_name,
                        distiller.sparsity(zeros_mask_dict[param_name].mask),
                        fraction_to_prune, num_rows_to_prune, rows_mags.size(0))
 
     @staticmethod
-    def rank_and_prune_blocks(fraction_to_prune, param, param_name=None,
-                              zeros_mask_dict=None, model=None, binary_map=None, block_shape=None):
+    def rank_and_prune_blocks(fraction_to_prune, param, param_name=None, zeros_mask_dict=None,
+                              model=None, binary_map=None, block_shape=None, magnitude_fn=l1_magnitude):
         """Block-wise pruning for 4D tensors.
         The block shape is specified using a tuple: [block_repetitions, block_depth, block_height, block_width].
@@ -251,26 +263,20 @@ class L1RankedStructureParameterPruner(RankedStructureParameterPruner):
         kernel_size = param.size(2) * param.size(3)
 
         if block_depth > 1:
-            view_dims = (
-                num_filters*num_channels//(block_repetitions*block_depth),
-                block_repetitions*block_depth,
-                kernel_size,
-                )
+            view_dims = (num_filters*num_channels//(block_repetitions*block_depth),
+                         block_repetitions*block_depth,
+                         kernel_size,)
         else:
-            view_dims = (
-                num_filters // block_repetitions,
-                block_repetitions,
-                -1,
-                )
+            view_dims = (num_filters // block_repetitions,
+                         block_repetitions,
+                         -1,)
 
         def rank_blocks(fraction_to_prune, param):
             # Create a view where each block is a column
             view1 = param.view(*view_dims)
             # Next, compute the sums of each column (block)
-            block_sums = view1.abs().sum(dim=1)
-
-            # Now group by channels
-            block_mags = block_sums.view(-1)  # flatten
+            block_mags = magnitude_fn(view1, dim=1)
+            block_mags = block_mags.view(-1)  # flatten
             k = int(fraction_to_prune * block_mags.size(0))
             if k == 0:
                 msglogger.info("Too few blocks (%d)- can't prune %.1f%% blocks",
@@ -302,6 +308,28 @@ class L1RankedStructureParameterPruner(RankedStructureParameterPruner):
         return binary_map
 
 
+class L1RankedStructureParameterPruner(LpRankedStructureParameterPruner):
+    """Uses mean L1-norm to rank and prune structures.
+
+    This class prunes to a prescribed percentage of structured-sparsity (level pruning).
+    """
+    def __init__(self, name, group_type, desired_sparsity, weights,
+                 group_dependency=None, kwargs=None):
+        super().__init__(name, group_type, desired_sparsity, weights,
+                         group_dependency, kwargs, magnitude_fn=l1_magnitude)
+
+
+class L2RankedStructureParameterPruner(LpRankedStructureParameterPruner):
+    """Uses mean L2-norm to rank and prune structures.
+
+    This class prunes to a prescribed percentage of structured-sparsity (level pruning).
+    """
+    def __init__(self, name, group_type, desired_sparsity, weights,
+                 group_dependency=None, kwargs=None):
+        super().__init__(name, group_type, desired_sparsity, weights,
+                         group_dependency, kwargs, magnitude_fn=l2_magnitude)
+
+
 def mask_from_filter_order(filters_ordered_by_criterion, param, num_filters, binary_map):
     if binary_map is None:
         binary_map = torch.zeros(num_filters).cuda()
@@ -324,7 +352,8 @@ class ActivationAPoZRankedFilterPruner(RankedStructureParameterPruner):
     def prune_group(self, fraction_to_prune, param, param_name, zeros_mask_dict, model=None, binary_map=None):
         if fraction_to_prune == 0:
             return
-        binary_map = self.rank_and_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map)
+        binary_map = self.rank_and_prune_filters(fraction_to_prune, param, param_name,
+                                                 zeros_mask_dict, model, binary_map)
         return binary_map
 
     def rank_and_prune_filters(self, fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map=None):
@@ -370,7 +399,8 @@ class RandomRankedFilterPruner(RankedStructureParameterPruner):
     def prune_group(self, fraction_to_prune, param, param_name, zeros_mask_dict, model=None, binary_map=None):
         if fraction_to_prune == 0:
             return
-        binary_map = self.rank_and_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map)
+        binary_map = self.rank_and_prune_filters(fraction_to_prune, param, param_name,
+                                                 zeros_mask_dict, model, binary_map)
         return binary_map
 
     def rank_and_prune_filters(self, fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map=None):
@@ -402,7 +432,8 @@ class GradientRankedFilterPruner(RankedStructureParameterPruner):
     def prune_group(self, fraction_to_prune, param, param_name, zeros_mask_dict, model=None, binary_map=None):
         if fraction_to_prune == 0:
             return
-        binary_map = self.rank_and_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map)
+        binary_map = self.rank_and_prune_filters(fraction_to_prune, param, param_name,
+                                                 zeros_mask_dict, model, binary_map)
         return binary_map
 
     def rank_and_prune_filters(self, fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map=None):
diff --git a/distiller/thresholding.py b/distiller/thresholding.py
index be03a27..4f0b845 100755
--- a/distiller/thresholding.py
+++ b/distiller/thresholding.py
@@ -151,7 +151,7 @@ def group_threshold_mask(param, group_type, threshold, threshold_criteria, binar
         if binary_map is None:
             binary_map = group_threshold_binary_map(param, group_type, threshold, threshold_criteria)
         a = binary_map.expand(param.size(1) * param.size(2) * param.size(3), param.size(0)).t()
-        return a.view(param.size(0), param.size(1), param.size(2), param.size(3)), binary_map
+        return a.view(*param.shape), binary_map
 
     elif group_type == '4D':
         assert param.dim() == 4, "This thresholding is only supported for 4D weights"
@@ -181,10 +181,14 @@ def group_threshold_mask(param, group_type, threshold, threshold_criteria, binar
 def threshold_policy(weights, thresholds, threshold_criteria, dim=1):
     """
     """
-    if threshold_criteria == 'Mean_Abs':
-        return weights.data.abs().mean(dim=dim).gt(thresholds).type(weights.type())
+    if threshold_criteria in ['Mean_Abs', 'Mean_L1']:
+        return weights.data.norm(p=1, dim=dim).div(weights.size(dim)).gt(thresholds).type(weights.type())
+    if threshold_criteria == 'Mean_L2':
+        return weights.data.norm(p=2, dim=dim).div(weights.size(dim)).gt(thresholds).type(weights.type())
     elif threshold_criteria == 'L1':
         return weights.data.norm(p=1, dim=dim).gt(thresholds).type(weights.type())
+    elif threshold_criteria == 'L2':
+        return weights.data.norm(p=2, dim=dim).gt(thresholds).type(weights.type())
     elif threshold_criteria == 'Max':
         maxv, _ = weights.data.abs().max(dim=dim)
         return maxv.gt(thresholds).type(weights.type())
--
GitLab
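
For readers who want to exercise the new ranking path in isolation: the sketch below replays the filter-ranking step of rank_and_prune_filters with l2_magnitude, outside of Distiller. The tensor shape and pruning fraction are invented for illustration; only the ranking logic mirrors the patch.

from functools import partial

import torch

# Same helpers the patch adds at module level.
l1_magnitude = partial(torch.norm, p=1)
l2_magnitude = partial(torch.norm, p=2)

param = torch.randn(64, 3, 3, 3)        # hypothetical Conv2d weight: 64 filters
fraction_to_prune = 0.5

# Flatten each filter to a row and rank rows by Lp magnitude,
# as rank_and_prune_filters now does via magnitude_fn.
view_filters = param.view(param.size(0), -1)
filter_mags = l2_magnitude(view_filters, dim=1)

# The pruning threshold is the largest magnitude among the bottom-k filters.
topk_filters = int(fraction_to_prune * filter_mags.size(0))
bottomk, _ = torch.topk(filter_mags, topk_filters, largest=False, sorted=True)
threshold = bottomk[-1]
binary_map = filter_mags.gt(threshold)  # True = filter survives
print(binary_map.sum().item(), "of", param.size(0), "filters survive")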
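The 'Channels' path aggregates differently: each kernel is scored individually and the scores are then averaged per input channel. A minimal sketch of just that grouping, with an invented 64x16x3x3 weight:

from functools import partial

import torch

l2_magnitude = partial(torch.norm, p=2)

param = torch.randn(64, 16, 3, 3)        # hypothetical weight: 64 filters, 16 channels
num_filters, num_channels = param.size(0), param.size(1)
kernel_size = param.size(2) * param.size(3)

view_2d = param.view(-1, kernel_size)             # one kernel per row
kernel_mags = l2_magnitude(view_2d, dim=1)        # Lp magnitude of each kernel
k_mags_mat = kernel_mags.view(num_filters, num_channels).t()
channel_mags = k_mags_mat.mean(dim=1)             # one score per input channel
print(channel_mags.shape)                         # torch.Size([16])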
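In thresholding.py, the old 'Mean_Abs' branch (weights.data.abs().mean(dim=dim)) is rewritten as an L1 norm divided by the group size. The two are algebraically identical, since mean(|w|) = ||w||_1 / n along the same dim; a quick check (not part of the patch):

import torch

w = torch.randn(8, 16)
old = w.abs().mean(dim=1)                  # pre-patch 'Mean_Abs'
new = w.norm(p=1, dim=1).div(w.size(1))    # post-patch 'Mean_Abs' / 'Mean_L1'
assert torch.allclose(old, new)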
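Finally, a hypothetical direct instantiation of the new class, based on the __init__ signature above. The parameter name is made up, and the import assumes the class is reachable at its module path; in practice Distiller schedules would construct this from a YAML configuration.

from distiller.pruning.ranked_structures_pruner import L2RankedStructureParameterPruner

# Prune 50% of the filters of a (made-up) conv layer, ranked by L2 magnitude.
pruner = L2RankedStructureParameterPruner(name="filter_pruner",
                                          group_type="Filters",
                                          desired_sparsity=0.5,
                                          weights=["module.conv1.weight"])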