diff --git a/distiller/pruning/ranked_structures_pruner.py b/distiller/pruning/ranked_structures_pruner.py
index 7b47e8e75d5feed97670dc742b3ab43f39b5442c..56454e7911d03cf60189c7695160401eaef20ca9 100755
--- a/distiller/pruning/ranked_structures_pruner.py
+++ b/distiller/pruning/ranked_structures_pruner.py
@@ -220,7 +220,7 @@ class LpRankedStructureParameterPruner(_RankedStructureParameterPruner):
     def rank_and_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, 
                                model=None, binary_map=None, magnitude_fn=l1_magnitude, 
                                noise=0.0, group_size=1, rounding_fn=math.floor):
-        assert param.dim() == 4, "This pruning is only supported for 4D weights"
+        assert param.dim() == 4 or param.dim() == 3, "This pruning is only supported for 3D and 4D weights"
 
         threshold = None
         threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2'
diff --git a/distiller/thresholding.py b/distiller/thresholding.py
index 52761ad58313294255b814010f329fcba2d5ceb6..55be4a132e31c5ac074b61f3eba334868b132a78 100755
--- a/distiller/thresholding.py
+++ b/distiller/thresholding.py
@@ -19,6 +19,7 @@
 The code below supports fine-grained tensor thresholding and group-wise thresholding.
 """
 import torch
+import numpy as np
 
 
 def threshold_mask(weights, threshold):
@@ -85,7 +86,7 @@ def group_threshold_binary_map(param, group_type, threshold, threshold_criteria)
         return binary_map
 
     elif group_type == '3D' or group_type == 'Filters':
-        assert param.dim() == 4, "This thresholding is only supported for 4D weights"
+        assert param.dim() == 4 or param.dim() == 3, "This pruning is only supported for 3D and 4D weights"
         view_filters = param.view(param.size(0), -1)
         thresholds = torch.Tensor([threshold] * param.size(0)).to(param.device)
         binary_map = threshold_policy(view_filters, thresholds, threshold_criteria)
@@ -153,7 +154,7 @@ def group_threshold_mask(param, group_type, threshold, threshold_criteria, binar
     elif group_type == '3D' or group_type == 'Filters':
         if binary_map is None:
             binary_map = group_threshold_binary_map(param, group_type, threshold, threshold_criteria)
-        a = binary_map.expand(param.size(1) * param.size(2) * param.size(3), param.size(0)).t()
+        a = binary_map.expand(np.prod(param.shape[1:]), param.size(0)).t()
         return a.view(*param.shape), binary_map
 
     elif group_type == '4D':