diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py index 7b47e8e75d5feed97670dc742b3ab43f39b5442c..56454e7911d03cf60189c7695160401eaef20ca9 100755 --- a/distiller/pruning/ranked_structures_pruner.py +++ b/distiller/pruning/ranked_structures_pruner.py @@ -220,7 +220,7 @@ class LpRankedStructureParameterPruner(_RankedStructureParameterPruner): def rank_and_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, model=None, binary_map=None, magnitude_fn=l1_magnitude, noise=0.0, group_size=1, rounding_fn=math.floor): - assert param.dim() == 4, "This pruning is only supported for 4D weights" + assert param.dim() == 4 or param.dim() == 3, "This pruning is only supported for 3D and 4D weights" threshold = None threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2' diff --git a/distiller/thresholding.py b/distiller/thresholding.py index 52761ad58313294255b814010f329fcba2d5ceb6..55be4a132e31c5ac074b61f3eba334868b132a78 100755 --- a/distiller/thresholding.py +++ b/distiller/thresholding.py @@ -19,6 +19,7 @@ The code below supports fine-grained tensor thresholding and group-wise thresholding. """ import torch +import numpy as np def threshold_mask(weights, threshold): @@ -85,7 +86,7 @@ def group_threshold_binary_map(param, group_type, threshold, threshold_criteria) return binary_map elif group_type == '3D' or group_type == 'Filters': - assert param.dim() == 4, "This thresholding is only supported for 4D weights" + assert param.dim() == 4 or param.dim() == 3, "This pruning is only supported for 3D and 4D weights" view_filters = param.view(param.size(0), -1) thresholds = torch.Tensor([threshold] * param.size(0)).to(param.device) binary_map = threshold_policy(view_filters, thresholds, threshold_criteria) @@ -153,7 +154,7 @@ def group_threshold_mask(param, group_type, threshold, threshold_criteria, binar elif group_type == '3D' or group_type == 'Filters': if binary_map is None: binary_map = group_threshold_binary_map(param, group_type, threshold, threshold_criteria) - a = binary_map.expand(param.size(1) * param.size(2) * param.size(3), param.size(0)).t() + a = binary_map.expand(np.prod(param.shape[1:]), param.size(0)).t() return a.view(*param.shape), binary_map elif group_type == '4D':