diff --git a/distiller/thresholding.py b/distiller/thresholding.py index 4f0b845e6ea7122f7e1d3b7eb9fabaecc9fe5353..52761ad58313294255b814010f329fcba2d5ceb6 100755 --- a/distiller/thresholding.py +++ b/distiller/thresholding.py @@ -40,7 +40,10 @@ class GroupThresholdMixin(object): TODO: this does not need to be a mixin - it should be made a simple function. We keep this until we refactor """ def group_threshold_mask(self, param, group_type, threshold, threshold_criteria): - return group_threshold_mask(param, group_type, threshold, threshold_criteria) + ret = group_threshold_mask(param, group_type, threshold, threshold_criteria) + if isinstance(ret, tuple): + return ret[0] + return ret def group_threshold_binary_map(param, group_type, threshold, threshold_criteria): @@ -98,7 +101,7 @@ def group_threshold_binary_map(param, group_type, threshold, threshold_criteria) if param.data.abs().max() > threshold: return None return torch.zeros_like(param.data) - exit("Invalid threshold_criteria {}".format(threshold_criteria)) + raise ValueError("Invalid threshold_criteria {}".format(threshold_criteria)) elif group_type == 'Channels': assert param.dim() == 4, "This thresholding is only supported for 4D weights" @@ -135,17 +138,17 @@ def group_threshold_mask(param, group_type, threshold, threshold_criteria, binar # 3. Finally, expand the thresholds and view as a 4D tensor a = binary_map.expand(param.size(2) * param.size(3), param.size(0) * param.size(1)).t() - return a.view(param.size(0), param.size(1), param.size(2), param.size(3)) + return a.view(param.size(0), param.size(1), param.size(2), param.size(3)), binary_map elif group_type == 'Rows': if binary_map is None: binary_map = group_threshold_binary_map(param, group_type, threshold, threshold_criteria) - return binary_map.expand(param.size(1), param.size(0)).t() + return binary_map.expand(param.size(1), param.size(0)).t(), binary_map elif group_type == 'Cols': if binary_map is None: binary_map = group_threshold_binary_map(param, group_type, threshold, threshold_criteria) - return binary_map.expand(param.size(0), param.size(1)) + return binary_map.expand(param.size(0), param.size(1)), binary_map elif group_type == '3D' or group_type == 'Filters': if binary_map is None: @@ -163,7 +166,7 @@ def group_threshold_mask(param, group_type, threshold, threshold_criteria, binar if param.data.abs().max() > threshold: return None return torch.zeros_like(param.data) - exit("Invalid threshold_criteria {}".format(threshold_criteria)) + raise ValueError("Invalid threshold_criteria {}".format(threshold_criteria)) elif group_type == 'Channels': if binary_map is None: @@ -175,7 +178,7 @@ def group_threshold_mask(param, group_type, threshold, threshold_criteria, binar a = binary_map.expand(num_filters, num_kernels_per_filter) c = a.unsqueeze(-1) d = c.expand(num_filters, num_kernels_per_filter, param.size(2) * param.size(3)).contiguous() - return d.view(param.size(0), param.size(1), param.size(2), param.size(3)) + return d.view(param.size(0), param.size(1), param.size(2), param.size(3)), binary_map def threshold_policy(weights, thresholds, threshold_criteria, dim=1): @@ -192,4 +195,4 @@ def threshold_policy(weights, thresholds, threshold_criteria, dim=1): elif threshold_criteria == 'Max': maxv, _ = weights.data.abs().max(dim=dim) return maxv.gt(thresholds).type(weights.type()) - exit("Invalid threshold_criteria {}".format(threshold_criteria)) + raise ValueError("Invalid threshold_criteria {}".format(threshold_criteria))