Skip to content
Snippets Groups Projects
Commit fd544c7b authored by Benoit Brummer's avatar Benoit Brummer Committed by Neta Zmora
Browse files

LpRankedStructureParameterPruner: work with 3D filters (ie Conv1d) (#348)

* LpRankedStructureParameterPruner: work with 3D filters (ie Conv1d)
parent d4cbab3f
No related branches found
No related tags found
No related merge requests found
......@@ -220,7 +220,7 @@ class LpRankedStructureParameterPruner(_RankedStructureParameterPruner):
def rank_and_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict,
model=None, binary_map=None, magnitude_fn=l1_magnitude,
noise=0.0, group_size=1, rounding_fn=math.floor):
assert param.dim() == 4, "This pruning is only supported for 4D weights"
assert param.dim() == 4 or param.dim() == 3, "This pruning is only supported for 3D and 4D weights"
threshold = None
threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2'
......
......@@ -19,6 +19,7 @@
The code below supports fine-grained tensor thresholding and group-wise thresholding.
"""
import torch
import numpy as np
def threshold_mask(weights, threshold):
......@@ -85,7 +86,7 @@ def group_threshold_binary_map(param, group_type, threshold, threshold_criteria)
return binary_map
elif group_type == '3D' or group_type == 'Filters':
assert param.dim() == 4, "This thresholding is only supported for 4D weights"
assert param.dim() == 4 or param.dim() == 3, "This pruning is only supported for 3D and 4D weights"
view_filters = param.view(param.size(0), -1)
thresholds = torch.Tensor([threshold] * param.size(0)).to(param.device)
binary_map = threshold_policy(view_filters, thresholds, threshold_criteria)
......@@ -153,7 +154,7 @@ def group_threshold_mask(param, group_type, threshold, threshold_criteria, binar
elif group_type == '3D' or group_type == 'Filters':
if binary_map is None:
binary_map = group_threshold_binary_map(param, group_type, threshold, threshold_criteria)
a = binary_map.expand(param.size(1) * param.size(2) * param.size(3), param.size(0)).t()
a = binary_map.expand(np.prod(param.shape[1:]), param.size(0)).t()
return a.view(*param.shape), binary_map
elif group_type == '4D':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment