From 89da7ce54b50e9e12dfe1f47a654ff01c20ca531 Mon Sep 17 00:00:00 2001
From: Neta Zmora <31280975+nzmora@users.noreply.github.com>
Date: Tue, 7 Aug 2018 08:10:44 -0500
Subject: [PATCH] Fix bug: thresholding matrix cols should use dim=0 (issue
 #39) (#40)

* Fix bug: thresholding matrix cols should use dim=0 (issue #39)

See issue #39 for @vinutah's description of the bug.

* thresholding test: fix device assignment
---
 distiller/thresholding.py  |  8 ++++----
 tests/test_thresholding.py | 41 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 45 insertions(+), 4 deletions(-)
 create mode 100755 tests/test_thresholding.py

diff --git a/distiller/thresholding.py b/distiller/thresholding.py
index f18899b..a00cbeb 100755
--- a/distiller/thresholding.py
+++ b/distiller/thresholding.py
@@ -73,7 +73,7 @@ class GroupThresholdMixin(object):
         elif group_type == 'Cols':
             assert param.dim() == 2, "This regularization is only supported for 2D weights"
             thresholds = torch.Tensor([threshold] * param.size(1)).cuda()
-            binary_map = self.threshold_policy(param, thresholds, threshold_criteria)
+            binary_map = self.threshold_policy(param, thresholds, threshold_criteria, dim=0)
             return binary_map.expand(param.size(0), param.size(1))
 
         elif group_type == '3D':
@@ -115,12 +115,12 @@ class GroupThresholdMixin(object):
             return d.view(param.size(0), param.size(1), param.size(2), param.size(3))
 
 
-    def threshold_policy(self, weights, thresholds, threshold_criteria):
+    def threshold_policy(self, weights, thresholds, threshold_criteria, dim=1):
         """
         """
         if threshold_criteria == 'Mean_Abs':
-            return weights.data.abs().mean(dim=1).gt(thresholds).type(weights.type())
+            return weights.data.abs().mean(dim=dim).gt(thresholds).type(weights.type())
         elif threshold_criteria == 'Max':
-            maxv, _ = weights.data.abs().max(dim=1)
+            maxv, _ = weights.data.abs().max(dim=dim)
             return maxv.gt(thresholds).type(weights.type())
         exit("Invalid threshold_criteria {}".format(threshold_criteria))
diff --git a/tests/test_thresholding.py b/tests/test_thresholding.py
new file mode 100755
index 0000000..130dcdc
--- /dev/null
+++ b/tests/test_thresholding.py
@@ -0,0 +1,41 @@
+import torch
+import os
+import sys
+import pytest
+module_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
+if module_path not in sys.path:  # parent of tests/, resolved independently of the cwd
+    sys.path.append(module_path)
+import distiller
+
+
+def get_test_tensor():
+    """Return a fixed 4x3 float tensor holding 1..12 in row-major order."""
+    rows = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0],
+            [7.0, 8.0, 9.0], [10., 11., 12.]]
+    return torch.tensor(rows)
+
+
+def test_row_thresholding():
+    """'Rows' + 'Max' at threshold 7: expect only the last two rows kept."""
+    weights = get_test_tensor().cuda()
+    mixin = distiller.GroupThresholdMixin()
+    mask = mixin.group_threshold_mask(weights, 'Rows', 7, 'Max')
+    expected = torch.tensor([[0., 0., 0.], [0., 0., 0.],
+                             [1., 1., 1.], [1., 1., 1.]], device=mask.device)
+    assert torch.eq(mask, expected).all()
+    return mask
+
+
+def test_col_thresholding():
+    """'Cols' + 'Max' at threshold 11: expect only the last column kept."""
+    weights = get_test_tensor().cuda()
+    mixin = distiller.GroupThresholdMixin()
+    mask = mixin.group_threshold_mask(weights, 'Cols', 11, 'Max')
+    expected = torch.tensor([[0., 0., 1.], [0., 0., 1.],
+                             [0., 0., 1.], [0., 0., 1.]], device=mask.device)
+    assert torch.eq(mask, expected).all()
+    return mask
+
+if __name__ == '__main__':
+    # Ad-hoc run outside pytest: execute the column test and print its mask.
+    print(test_col_thresholding())
-- 
GitLab