Skip to content
Snippets Groups Projects
Commit d75ba51e authored by Neta Zmora's avatar Neta Zmora
Browse files

Fix row-pruning

A recent change requires us to return the binary_map from the
ranking operation, and this was missing for the row-pruning case.
parent cd2c5e73
No related branches found
No related tags found
No related merge requests found
...@@ -228,11 +228,12 @@ class LpRankedStructureParameterPruner(RankedStructureParameterPruner): ...@@ -228,11 +228,12 @@ class LpRankedStructureParameterPruner(RankedStructureParameterPruner):
bottomk_rows, _ = torch.topk(rows_mags, num_rows_to_prune, largest=False, sorted=True) bottomk_rows, _ = torch.topk(rows_mags, num_rows_to_prune, largest=False, sorted=True)
threshold = bottomk_rows[-1] threshold = bottomk_rows[-1]
threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2' threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2'
zeros_mask_dict[param_name].mask = distiller.group_threshold_mask(param, THRESHOLD_DIM, zeros_mask_dict[param_name].mask, binary_map = distiller.group_threshold_mask(param, THRESHOLD_DIM,
threshold, threshold_type) threshold, threshold_type)
msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name, msglogger.info("L1RankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name,
distiller.sparsity(zeros_mask_dict[param_name].mask), distiller.sparsity(zeros_mask_dict[param_name].mask),
fraction_to_prune, num_rows_to_prune, rows_mags.size(0)) fraction_to_prune, num_rows_to_prune, rows_mags.size(0))
return binary_map
@staticmethod @staticmethod
def rank_and_prune_blocks(fraction_to_prune, param, param_name=None, zeros_mask_dict=None, def rank_and_prune_blocks(fraction_to_prune, param, param_name=None, zeros_mask_dict=None,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment