Skip to content
Snippets Groups Projects
Commit 642996cb authored by Neta Zmora's avatar Neta Zmora
Browse files

Group Lasso regularization: fix thresholding

Fixed the return value from GroupThresholdMixin.group_threshold_mask
so that it returns only the mask in all cases. This code is _not_
under test at the moment, and changes to the pruning code (which
also uses the thresholding code) led to this bug.
Need to add tests for group-lasso regularization.
parent 2daf616f
No related branches found
No related tags found
No related merge requests found
...@@ -40,7 +40,10 @@ class GroupThresholdMixin(object): ...@@ -40,7 +40,10 @@ class GroupThresholdMixin(object):
TODO: this does not need to be a mixin - it should be made a simple function. We keep this until we refactor TODO: this does not need to be a mixin - it should be made a simple function. We keep this until we refactor
""" """
def group_threshold_mask(self, param, group_type, threshold, threshold_criteria): def group_threshold_mask(self, param, group_type, threshold, threshold_criteria):
return group_threshold_mask(param, group_type, threshold, threshold_criteria) ret = group_threshold_mask(param, group_type, threshold, threshold_criteria)
if isinstance(ret, tuple):
return ret[0]
return ret
def group_threshold_binary_map(param, group_type, threshold, threshold_criteria): def group_threshold_binary_map(param, group_type, threshold, threshold_criteria):
...@@ -98,7 +101,7 @@ def group_threshold_binary_map(param, group_type, threshold, threshold_criteria) ...@@ -98,7 +101,7 @@ def group_threshold_binary_map(param, group_type, threshold, threshold_criteria)
if param.data.abs().max() > threshold: if param.data.abs().max() > threshold:
return None return None
return torch.zeros_like(param.data) return torch.zeros_like(param.data)
exit("Invalid threshold_criteria {}".format(threshold_criteria)) raise ValueError("Invalid threshold_criteria {}".format(threshold_criteria))
elif group_type == 'Channels': elif group_type == 'Channels':
assert param.dim() == 4, "This thresholding is only supported for 4D weights" assert param.dim() == 4, "This thresholding is only supported for 4D weights"
...@@ -135,17 +138,17 @@ def group_threshold_mask(param, group_type, threshold, threshold_criteria, binar ...@@ -135,17 +138,17 @@ def group_threshold_mask(param, group_type, threshold, threshold_criteria, binar
# 3. Finally, expand the thresholds and view as a 4D tensor # 3. Finally, expand the thresholds and view as a 4D tensor
a = binary_map.expand(param.size(2) * param.size(3), a = binary_map.expand(param.size(2) * param.size(3),
param.size(0) * param.size(1)).t() param.size(0) * param.size(1)).t()
return a.view(param.size(0), param.size(1), param.size(2), param.size(3)) return a.view(param.size(0), param.size(1), param.size(2), param.size(3)), binary_map
elif group_type == 'Rows': elif group_type == 'Rows':
if binary_map is None: if binary_map is None:
binary_map = group_threshold_binary_map(param, group_type, threshold, threshold_criteria) binary_map = group_threshold_binary_map(param, group_type, threshold, threshold_criteria)
return binary_map.expand(param.size(1), param.size(0)).t() return binary_map.expand(param.size(1), param.size(0)).t(), binary_map
elif group_type == 'Cols': elif group_type == 'Cols':
if binary_map is None: if binary_map is None:
binary_map = group_threshold_binary_map(param, group_type, threshold, threshold_criteria) binary_map = group_threshold_binary_map(param, group_type, threshold, threshold_criteria)
return binary_map.expand(param.size(0), param.size(1)) return binary_map.expand(param.size(0), param.size(1)), binary_map
elif group_type == '3D' or group_type == 'Filters': elif group_type == '3D' or group_type == 'Filters':
if binary_map is None: if binary_map is None:
...@@ -163,7 +166,7 @@ def group_threshold_mask(param, group_type, threshold, threshold_criteria, binar ...@@ -163,7 +166,7 @@ def group_threshold_mask(param, group_type, threshold, threshold_criteria, binar
if param.data.abs().max() > threshold: if param.data.abs().max() > threshold:
return None return None
return torch.zeros_like(param.data) return torch.zeros_like(param.data)
exit("Invalid threshold_criteria {}".format(threshold_criteria)) raise ValueError("Invalid threshold_criteria {}".format(threshold_criteria))
elif group_type == 'Channels': elif group_type == 'Channels':
if binary_map is None: if binary_map is None:
...@@ -175,7 +178,7 @@ def group_threshold_mask(param, group_type, threshold, threshold_criteria, binar ...@@ -175,7 +178,7 @@ def group_threshold_mask(param, group_type, threshold, threshold_criteria, binar
a = binary_map.expand(num_filters, num_kernels_per_filter) a = binary_map.expand(num_filters, num_kernels_per_filter)
c = a.unsqueeze(-1) c = a.unsqueeze(-1)
d = c.expand(num_filters, num_kernels_per_filter, param.size(2) * param.size(3)).contiguous() d = c.expand(num_filters, num_kernels_per_filter, param.size(2) * param.size(3)).contiguous()
return d.view(param.size(0), param.size(1), param.size(2), param.size(3)) return d.view(param.size(0), param.size(1), param.size(2), param.size(3)), binary_map
def threshold_policy(weights, thresholds, threshold_criteria, dim=1): def threshold_policy(weights, thresholds, threshold_criteria, dim=1):
...@@ -192,4 +195,4 @@ def threshold_policy(weights, thresholds, threshold_criteria, dim=1): ...@@ -192,4 +195,4 @@ def threshold_policy(weights, thresholds, threshold_criteria, dim=1):
elif threshold_criteria == 'Max': elif threshold_criteria == 'Max':
maxv, _ = weights.data.abs().max(dim=dim) maxv, _ = weights.data.abs().max(dim=dim)
return maxv.gt(thresholds).type(weights.type()) return maxv.gt(thresholds).type(weights.type())
exit("Invalid threshold_criteria {}".format(threshold_criteria)) raise ValueError("Invalid threshold_criteria {}".format(threshold_criteria))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment