diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py index 4816410c900b912ceb64852edac7b14a33cbbf8b..b4a667cb931328fc3e022b79c856213028e9c891 100755 --- a/distiller/pruning/ranked_structures_pruner.py +++ b/distiller/pruning/ranked_structures_pruner.py @@ -20,6 +20,7 @@ import distiller from .pruner import _ParameterPruner msglogger = logging.getLogger() + # TODO: support different policies for ranking structures class L1RankedStructureParameterPruner(_ParameterPruner): """Uses mean L1-norm to rank structures and prune a specified percentage of structures @@ -94,7 +95,7 @@ class L1RankedStructureParameterPruner(_ParameterPruner): def rank_prune_filters(self, fraction_to_prune, param, param_name, zeros_mask_dict): assert param.dim() == 4, "This thresholding is only supported for 4D weights" view_filters = param.view(param.size(0), -1) - filter_mags = view_filters.data.abs().mean(dim=1) + filter_mags = view_filters.data.norm(1, dim=1) # same as view_filters.data.abs().sum(dim=1) topk_filters = int(fraction_to_prune * filter_mags.size(0)) if topk_filters == 0: msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune)