From ba1ee25b90b6fa6d85763d04c34f087f2016f986 Mon Sep 17 00:00:00 2001 From: Bar <elhararb@gmail.com> Date: Sun, 17 Mar 2019 14:05:26 +0200 Subject: [PATCH] Fix logging message for structured pruning (#194) Modify LpRankedStructureParameterPruner to log as the correct class name for both L1 and L2 pruners. --- distiller/pruning/ranked_structures_pruner.py | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py index 2ee79cb..a38c87c 100755 --- a/distiller/pruning/ranked_structures_pruner.py +++ b/distiller/pruning/ranked_structures_pruner.py @@ -165,9 +165,11 @@ class LpRankedStructureParameterPruner(RankedStructureParameterPruner): threshold = bottomk_channels[-1] binary_map = channel_mags.gt(threshold).type(param.data.type()) + threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2' if zeros_mask_dict is not None: zeros_mask_dict[param_name].mask = binary_map_to_mask(binary_map, param) - msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name, + msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", + threshold_type, param_name, distiller.sparsity_ch(zeros_mask_dict[param_name].mask), fraction_to_prune, binary_map.sum().item(), param.size(1)) return binary_map @@ -178,6 +180,7 @@ class LpRankedStructureParameterPruner(RankedStructureParameterPruner): assert param.dim() == 4, "This pruning is only supported for 4D weights" threshold = None + threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2' if binary_map is None: # First we rank the filters view_filters = param.view(param.size(0), -1) @@ -188,16 +191,15 @@ class LpRankedStructureParameterPruner(RankedStructureParameterPruner): return bottomk, _ = torch.topk(filter_mags, topk_filters, largest=False, sorted=True) threshold = bottomk[-1] - msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=(%d/%d)", - param_name, + msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=(%d/%d)", + threshold_type, param_name, topk_filters, filter_mags.size(0)) # Then we threshold - threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2' mask, binary_map = distiller.group_threshold_mask(param, 'Filters', threshold, threshold_type, binary_map) if zeros_mask_dict is not None: zeros_mask_dict[param_name].mask = mask - msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f", - param_name, + msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f", + threshold_type, param_name, distiller.sparsity(mask), fraction_to_prune) @@ -231,7 +233,8 @@ class LpRankedStructureParameterPruner(RankedStructureParameterPruner): threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2' zeros_mask_dict[param_name].mask, binary_map = distiller.group_threshold_mask(param, THRESHOLD_DIM, threshold, threshold_type) - msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name, + msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", + threshold_type, param_name, distiller.sparsity(zeros_mask_dict[param_name].mask), fraction_to_prune, num_rows_to_prune, rows_mags.size(0)) return binary_map @@ -309,9 +312,11 @@ class LpRankedStructureParameterPruner(RankedStructureParameterPruner): threshold = bottomk_blocks[-1] binary_map = block_mags.gt(threshold).type(param.data.type()) + threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2' if zeros_mask_dict is not None: zeros_mask_dict[param_name].mask = binary_map_to_mask(binary_map, param) - msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name, + msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", + threshold_type, param_name, distiller.sparsity_blocks(zeros_mask_dict[param_name].mask, block_shape=block_shape), fraction_to_prune, binary_map.sum().item(), num_super_blocks) return binary_map -- GitLab