Skip to content
Snippets Groups Projects
Commit 642996cb authored by Neta Zmora's avatar Neta Zmora
Browse files

Group Lasso regularization: fix thresholding

Fixed the return value from GroupThresholdMixin.group_threshold_mask
so that it only returns the mask in all cases (this code is _not_
under test at the moment, and changes to the pruning code, which
also uses the thresholding code) led to this bug.
Need to add tests for group-lasso regularization.
parent 2daf616f
No related branches found
No related tags found
No related merge requests found
......@@ -40,7 +40,10 @@ class GroupThresholdMixin(object):
TODO: this does not need to be a mixin - it should be made a simple function. We keep this until we refactor
"""
def group_threshold_mask(self, param, group_type, threshold, threshold_criteria):
return group_threshold_mask(param, group_type, threshold, threshold_criteria)
ret = group_threshold_mask(param, group_type, threshold, threshold_criteria)
if isinstance(ret, tuple):
return ret[0]
return ret
def group_threshold_binary_map(param, group_type, threshold, threshold_criteria):
......@@ -98,7 +101,7 @@ def group_threshold_binary_map(param, group_type, threshold, threshold_criteria)
if param.data.abs().max() > threshold:
return None
return torch.zeros_like(param.data)
exit("Invalid threshold_criteria {}".format(threshold_criteria))
raise ValueError("Invalid threshold_criteria {}".format(threshold_criteria))
elif group_type == 'Channels':
assert param.dim() == 4, "This thresholding is only supported for 4D weights"
......@@ -135,17 +138,17 @@ def group_threshold_mask(param, group_type, threshold, threshold_criteria, binar
# 3. Finally, expand the thresholds and view as a 4D tensor
a = binary_map.expand(param.size(2) * param.size(3),
param.size(0) * param.size(1)).t()
return a.view(param.size(0), param.size(1), param.size(2), param.size(3))
return a.view(param.size(0), param.size(1), param.size(2), param.size(3)), binary_map
elif group_type == 'Rows':
if binary_map is None:
binary_map = group_threshold_binary_map(param, group_type, threshold, threshold_criteria)
return binary_map.expand(param.size(1), param.size(0)).t()
return binary_map.expand(param.size(1), param.size(0)).t(), binary_map
elif group_type == 'Cols':
if binary_map is None:
binary_map = group_threshold_binary_map(param, group_type, threshold, threshold_criteria)
return binary_map.expand(param.size(0), param.size(1))
return binary_map.expand(param.size(0), param.size(1)), binary_map
elif group_type == '3D' or group_type == 'Filters':
if binary_map is None:
......@@ -163,7 +166,7 @@ def group_threshold_mask(param, group_type, threshold, threshold_criteria, binar
if param.data.abs().max() > threshold:
return None
return torch.zeros_like(param.data)
exit("Invalid threshold_criteria {}".format(threshold_criteria))
raise ValueError("Invalid threshold_criteria {}".format(threshold_criteria))
elif group_type == 'Channels':
if binary_map is None:
......@@ -175,7 +178,7 @@ def group_threshold_mask(param, group_type, threshold, threshold_criteria, binar
a = binary_map.expand(num_filters, num_kernels_per_filter)
c = a.unsqueeze(-1)
d = c.expand(num_filters, num_kernels_per_filter, param.size(2) * param.size(3)).contiguous()
return d.view(param.size(0), param.size(1), param.size(2), param.size(3))
return d.view(param.size(0), param.size(1), param.size(2), param.size(3)), binary_map
def threshold_policy(weights, thresholds, threshold_criteria, dim=1):
......@@ -192,4 +195,4 @@ def threshold_policy(weights, thresholds, threshold_criteria, dim=1):
elif threshold_criteria == 'Max':
maxv, _ = weights.data.abs().max(dim=dim)
return maxv.gt(thresholds).type(weights.type())
exit("Invalid threshold_criteria {}".format(threshold_criteria))
raise ValueError("Invalid threshold_criteria {}".format(threshold_criteria))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment