diff --git a/distiller/thresholding.py b/distiller/thresholding.py
index 4f0b845e6ea7122f7e1d3b7eb9fabaecc9fe5353..52761ad58313294255b814010f329fcba2d5ceb6 100755
--- a/distiller/thresholding.py
+++ b/distiller/thresholding.py
@@ -40,7 +40,10 @@ class GroupThresholdMixin(object):
     TODO: this does not need to be a mixin - it should be made a simple function.  We keep this until we refactor
     """
     def group_threshold_mask(self, param, group_type, threshold, threshold_criteria):
-        return group_threshold_mask(param, group_type, threshold, threshold_criteria)
+        # The module-level function may return (mask, binary_map); this method returns only the mask.
+        outcome = group_threshold_mask(param, group_type, threshold, threshold_criteria)
+        mask = outcome[0] if isinstance(outcome, tuple) else outcome
+        return mask
 
 
 def group_threshold_binary_map(param, group_type, threshold, threshold_criteria):
@@ -98,7 +101,7 @@ def group_threshold_binary_map(param, group_type, threshold, threshold_criteria)
             if param.data.abs().max() > threshold:
                 return None
             return torch.zeros_like(param.data)
-        exit("Invalid threshold_criteria {}".format(threshold_criteria))
+        raise ValueError("Invalid threshold_criteria {}".format(threshold_criteria))
 
     elif group_type == 'Channels':
         assert param.dim() == 4, "This thresholding is only supported for 4D weights"
@@ -135,17 +138,17 @@ def group_threshold_mask(param, group_type, threshold, threshold_criteria, binar
         # 3. Finally, expand the thresholds and view as a 4D tensor
         a = binary_map.expand(param.size(2) * param.size(3),
                               param.size(0) * param.size(1)).t()
-        return a.view(param.size(0), param.size(1), param.size(2), param.size(3))
+        return a.view(param.size(0), param.size(1), param.size(2), param.size(3)), binary_map
 
     elif group_type == 'Rows':
         if binary_map is None:
             binary_map = group_threshold_binary_map(param, group_type, threshold, threshold_criteria)
-        return binary_map.expand(param.size(1), param.size(0)).t()
+        return binary_map.expand(param.size(1), param.size(0)).t(), binary_map
 
     elif group_type == 'Cols':
         if binary_map is None:
             binary_map = group_threshold_binary_map(param, group_type, threshold, threshold_criteria)
-        return binary_map.expand(param.size(0), param.size(1))
+        return binary_map.expand(param.size(0), param.size(1)), binary_map
 
     elif group_type == '3D' or group_type == 'Filters':
         if binary_map is None:
@@ -163,7 +166,7 @@ def group_threshold_mask(param, group_type, threshold, threshold_criteria, binar
             if param.data.abs().max() > threshold:
                 return None
             return torch.zeros_like(param.data)
-        exit("Invalid threshold_criteria {}".format(threshold_criteria))
+        raise ValueError("Invalid threshold_criteria {}".format(threshold_criteria))
 
     elif group_type == 'Channels':
         if binary_map is None:
@@ -175,7 +178,7 @@ def group_threshold_mask(param, group_type, threshold, threshold_criteria, binar
         a = binary_map.expand(num_filters, num_kernels_per_filter)
         c = a.unsqueeze(-1)
         d = c.expand(num_filters, num_kernels_per_filter, param.size(2) * param.size(3)).contiguous()
-        return d.view(param.size(0), param.size(1), param.size(2), param.size(3))
+        return d.view(param.size(0), param.size(1), param.size(2), param.size(3)), binary_map
 
 
 def threshold_policy(weights, thresholds, threshold_criteria, dim=1):
@@ -192,4 +195,4 @@ def threshold_policy(weights, thresholds, threshold_criteria, dim=1):
     elif threshold_criteria == 'Max':
         maxv, _ = weights.data.abs().max(dim=dim)
         return maxv.gt(thresholds).type(weights.type())
-    exit("Invalid threshold_criteria {}".format(threshold_criteria))
+    raise ValueError("Invalid threshold_criteria {}".format(threshold_criteria))