Skip to content
Snippets Groups Projects
Commit 9af786b6 authored by Guy Jacob's avatar Guy Jacob
Browse files

Bug fix in n_stds clipping calculation (fix issue #249)

parent 5814bcd5
No related branches found
No related tags found
No related merge requests found
......@@ -135,7 +135,9 @@ def get_tensor_avg_max_abs(t, across_dim=None):
return torch.max(avg_min.abs_(), avg_max.abs_())
def get_tensor_mean_n_stds_min_max(t, n_stds=1):
def get_tensor_mean_n_stds_min_max(t, dim=None, n_stds=1):
if dim is not None:
raise NotImplementedError('Setting dim != None not supported yet')
if n_stds <= 0:
raise ValueError('n_stds must be > 0, got {}'.format(n_stds))
mean = t.mean()
......@@ -146,8 +148,8 @@ def get_tensor_mean_n_stds_min_max(t, n_stds=1):
return min_val, max_val
def get_tensor_mean_n_stds_max_abs(t, n_stds=1):
min_val, max_val = get_tensor_mean_n_stds_min_max(t, n_stds)
def get_tensor_mean_n_stds_max_abs(t, dim=None, n_stds=1):
min_val, max_val = get_tensor_mean_n_stds_min_max(t, dim, n_stds)
return torch.max(min_val.abs_(), max_val.abs_())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment