diff --git a/distiller/thresholding.py b/distiller/thresholding.py
index f18899ba5a36039c78b46cf4bf5baabbcfdaaa2c..a00cbebe66c31206802cf68f3bf6e42ed20082db 100755
--- a/distiller/thresholding.py
+++ b/distiller/thresholding.py
@@ -73,7 +73,7 @@ class GroupThresholdMixin(object):
         elif group_type == 'Cols':
             assert param.dim() == 2, "This regularization is only supported for 2D weights"
             thresholds = torch.Tensor([threshold] * param.size(1)).cuda()
-            binary_map = self.threshold_policy(param, thresholds, threshold_criteria)
+            binary_map = self.threshold_policy(param, thresholds, threshold_criteria, dim=0)  # reduce over rows: one value per column, matching the per-column thresholds
             return binary_map.expand(param.size(0), param.size(1))
 
         elif group_type == '3D':
@@ -115,12 +115,24 @@ class GroupThresholdMixin(object):
             return d.view(param.size(0), param.size(1), param.size(2), param.size(3))
 
 
-    def threshold_policy(self, weights, thresholds, threshold_criteria):
-        """
-        """
+    def threshold_policy(self, weights, thresholds, threshold_criteria, dim=1):
+        """Return a 0/1 mask of the groups whose magnitude exceeds `thresholds`.
+
+        The absolute values of `weights` are reduced along `dim` (mean for
+        'Mean_Abs', max for 'Max'); each reduced value is then compared with
+        `thresholds` using a strict greater-than, and the comparison result is
+        cast to the tensor type of `weights`.
+
+        `dim` defaults to 1, the axis that used to be hard-coded, so callers
+        that do not pass it get the previous behaviour.
+
+        Raises:
+            ValueError: if `threshold_criteria` is not 'Mean_Abs' or 'Max'.
+        """
         if threshold_criteria == 'Mean_Abs':
-            return weights.data.abs().mean(dim=1).gt(thresholds).type(weights.type())
+            return weights.data.abs().mean(dim=dim).gt(thresholds).type(weights.type())
         elif threshold_criteria == 'Max':
-            maxv, _ = weights.data.abs().max(dim=1)
+            maxv, _ = weights.data.abs().max(dim=dim)
             return maxv.gt(thresholds).type(weights.type())
-        exit("Invalid threshold_criteria {}".format(threshold_criteria))
+        # Raise instead of exit(): a bad criterion must not terminate the caller's process.
+        raise ValueError("Invalid threshold_criteria {}".format(threshold_criteria))
diff --git a/tests/test_thresholding.py b/tests/test_thresholding.py
new file mode 100755
index 0000000000000000000000000000000000000000..130dcdc42f44b5cd1883315b9f5ddb3abbcd3953
--- /dev/null
+++ b/tests/test_thresholding.py
@@ -0,0 +1,58 @@
+"""Tests for the row- and column-wise masks built by GroupThresholdMixin."""
+import torch
+import os
+import sys
+import pytest
+# Locate the repository root from this file's location, not the current working
+# directory, so `import distiller` works wherever pytest is launched from.
+module_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+if module_path not in sys.path:
+    sys.path.append(module_path)
+import distiller
+
+# Every test moves its input to the GPU with .cuda(), so skip (rather than
+# error out) on machines without CUDA.
+pytestmark = pytest.mark.skipif(not torch.cuda.is_available(),
+                                reason="group thresholding tests need CUDA")
+
+
+def get_test_tensor():
+    """Return a 4x3 float tensor holding 1..12 in row-major order."""
+    return torch.tensor([[1.0, 2.0, 3.0],
+                         [4.0, 5.0, 6.0],
+                         [7.0, 8.0, 9.0],
+                         [10., 11., 12.]])
+
+
+def _max_group_mask(group_type, threshold):
+    """Return the 'Max'-criterion mask of the test tensor for `group_type`."""
+    p = get_test_tensor().cuda()
+    group_th = distiller.GroupThresholdMixin()
+    return group_th.group_threshold_mask(p, group_type, threshold, 'Max')
+
+
+def test_row_thresholding():
+    """Rows whose largest magnitude exceeds 7 get an all-ones mask row."""
+    mask = _max_group_mask('Rows', 7)
+    expected = torch.tensor([[0., 0., 0.],
+                             [0., 0., 0.],
+                             [1., 1., 1.],
+                             [1., 1., 1.]], device=mask.device)
+    # torch.equal also checks the shape, so a mis-shaped mask cannot pass by broadcasting.
+    assert torch.equal(mask, expected)
+
+
+def test_col_thresholding():
+    """Only the last column's largest magnitude exceeds 11, so only it is masked in."""
+    mask = _max_group_mask('Cols', 11)
+    expected = torch.tensor([[0., 0., 1.],
+                             [0., 0., 1.],
+                             [0., 0., 1.],
+                             [0., 0., 1.]], device=mask.device)
+    assert torch.equal(mask, expected)
+
+
+if __name__ == '__main__':
+    # Manual run: check the column case, then show its mask.
+    test_col_thresholding()
+    print(_max_group_mask('Cols', 11))