diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py index b5f990a6922a6eb09d017024d4385061c9671bc8..8be38b5c5d79f4447004284be8f20eb9d0506c5b 100755 --- a/distiller/pruning/ranked_structures_pruner.py +++ b/distiller/pruning/ranked_structures_pruner.py @@ -228,11 +228,12 @@ class LpRankedStructureParameterPruner(RankedStructureParameterPruner): bottomk_rows, _ = torch.topk(rows_mags, num_rows_to_prune, largest=False, sorted=True) threshold = bottomk_rows[-1] threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2' - zeros_mask_dict[param_name].mask = distiller.group_threshold_mask(param, THRESHOLD_DIM, - threshold, threshold_type) + zeros_mask_dict[param_name].mask, binary_map = distiller.group_threshold_mask(param, THRESHOLD_DIM, + threshold, threshold_type) msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name, distiller.sparsity(zeros_mask_dict[param_name].mask), fraction_to_prune, num_rows_to_prune, rows_mags.size(0)) + return binary_map @staticmethod def rank_and_prune_blocks(fraction_to_prune, param, param_name=None, zeros_mask_dict=None,