Skip to content
Snippets Groups Projects
Commit ba1ee25b authored by Bar's avatar Bar Committed by Neta Zmora
Browse files

Fix logging message for structured pruning (#194)

Modify LpRankedStructureParameterPruner to log as
the correct class name for both L1 and L2 pruners.
parent 958b0e6d
No related branches found
No related tags found
No related merge requests found
......@@ -165,9 +165,11 @@ class LpRankedStructureParameterPruner(RankedStructureParameterPruner):
threshold = bottomk_channels[-1]
binary_map = channel_mags.gt(threshold).type(param.data.type())
threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2'
if zeros_mask_dict is not None:
zeros_mask_dict[param_name].mask = binary_map_to_mask(binary_map, param)
msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name,
msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
threshold_type, param_name,
distiller.sparsity_ch(zeros_mask_dict[param_name].mask),
fraction_to_prune, binary_map.sum().item(), param.size(1))
return binary_map
......@@ -178,6 +180,7 @@ class LpRankedStructureParameterPruner(RankedStructureParameterPruner):
assert param.dim() == 4, "This pruning is only supported for 4D weights"
threshold = None
threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2'
if binary_map is None:
# First we rank the filters
view_filters = param.view(param.size(0), -1)
......@@ -188,16 +191,15 @@ class LpRankedStructureParameterPruner(RankedStructureParameterPruner):
return
bottomk, _ = torch.topk(filter_mags, topk_filters, largest=False, sorted=True)
threshold = bottomk[-1]
msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=(%d/%d)",
param_name,
msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=(%d/%d)",
threshold_type, param_name,
topk_filters, filter_mags.size(0))
# Then we threshold
threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2'
mask, binary_map = distiller.group_threshold_mask(param, 'Filters', threshold, threshold_type, binary_map)
if zeros_mask_dict is not None:
zeros_mask_dict[param_name].mask = mask
msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f",
param_name,
msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f",
threshold_type, param_name,
distiller.sparsity(mask),
fraction_to_prune)
......@@ -231,7 +233,8 @@ class LpRankedStructureParameterPruner(RankedStructureParameterPruner):
threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2'
zeros_mask_dict[param_name].mask, binary_map = distiller.group_threshold_mask(param, THRESHOLD_DIM,
threshold, threshold_type)
msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name,
msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
threshold_type, param_name,
distiller.sparsity(zeros_mask_dict[param_name].mask),
fraction_to_prune, num_rows_to_prune, rows_mags.size(0))
return binary_map
......@@ -309,9 +312,11 @@ class LpRankedStructureParameterPruner(RankedStructureParameterPruner):
threshold = bottomk_blocks[-1]
binary_map = block_mags.gt(threshold).type(param.data.type())
threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2'
if zeros_mask_dict is not None:
zeros_mask_dict[param_name].mask = binary_map_to_mask(binary_map, param)
msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name,
msglogger.info("%sRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)",
threshold_type, param_name,
distiller.sparsity_blocks(zeros_mask_dict[param_name].mask, block_shape=block_shape),
fraction_to_prune, binary_map.sum().item(), num_super_blocks)
return binary_map
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment