diff --git a/distiller/thresholding.py b/distiller/thresholding.py index f18899ba5a36039c78b46cf4bf5baabbcfdaaa2c..a00cbebe66c31206802cf68f3bf6e42ed20082db 100755 --- a/distiller/thresholding.py +++ b/distiller/thresholding.py @@ -73,7 +73,7 @@ class GroupThresholdMixin(object): elif group_type == 'Cols': assert param.dim() == 2, "This regularization is only supported for 2D weights" thresholds = torch.Tensor([threshold] * param.size(1)).cuda() - binary_map = self.threshold_policy(param, thresholds, threshold_criteria) + binary_map = self.threshold_policy(param, thresholds, threshold_criteria, dim=0) return binary_map.expand(param.size(0), param.size(1)) elif group_type == '3D': @@ -115,12 +115,12 @@ class GroupThresholdMixin(object): return d.view(param.size(0), param.size(1), param.size(2), param.size(3)) - def threshold_policy(self, weights, thresholds, threshold_criteria): + def threshold_policy(self, weights, thresholds, threshold_criteria, dim=1): """ """ if threshold_criteria == 'Mean_Abs': - return weights.data.abs().mean(dim=1).gt(thresholds).type(weights.type()) + return weights.data.abs().mean(dim=dim).gt(thresholds).type(weights.type()) elif threshold_criteria == 'Max': - maxv, _ = weights.data.abs().max(dim=1) + maxv, _ = weights.data.abs().max(dim=dim) return maxv.gt(thresholds).type(weights.type()) exit("Invalid threshold_criteria {}".format(threshold_criteria)) diff --git a/tests/test_thresholding.py b/tests/test_thresholding.py new file mode 100755 index 0000000000000000000000000000000000000000..130dcdc42f44b5cd1883315b9f5ddb3abbcd3953 --- /dev/null +++ b/tests/test_thresholding.py @@ -0,0 +1,41 @@ +import torch +import os +import sys +import pytest +module_path = os.path.abspath(os.path.join('..')) +if module_path not in sys.path: + sys.path.append(module_path) +import distiller + + +def get_test_tensor(): + return torch.tensor([[1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + [7.0, 8.0, 9.0], + [10., 11., 12.]]) + + +def test_row_thresholding(): + p = get_test_tensor().cuda() + group_th = distiller.GroupThresholdMixin() + mask = group_th.group_threshold_mask(p, 'Rows', 7, 'Max') + assert torch.eq(mask, torch.tensor([[ 0., 0., 0.], + [ 0., 0., 0.], + [ 1., 1., 1.], + [ 1., 1., 1.]], device=mask.device)).all() + return mask + + +def test_col_thresholding(): + p = get_test_tensor().cuda() + group_th = distiller.GroupThresholdMixin() + mask = group_th.group_threshold_mask(p, 'Cols', 11, 'Max') + assert torch.eq(mask, torch.tensor([[ 0., 0., 1.], + [ 0., 0., 1.], + [ 0., 0., 1.], + [ 0., 0., 1.]], device=mask.device)).all() + return mask + +if __name__ == '__main__': + m = test_col_thresholding() + print(m)