Skip to content
Snippets Groups Projects
Unverified Commit 58470d9f authored by Guy Jacob's avatar Guy Jacob Committed by GitHub
Browse files

ACIQ bug fixes

* Return both min and max clip value in the symmetric case
* Correct delta from mean in asymmetric + half_range case
parent 0036011d
No related branches found
No related tags found
No related merge requests found
...@@ -230,7 +230,9 @@ class AciqSymmetricClipper(AciqClipper): ...@@ -230,7 +230,9 @@ class AciqSymmetricClipper(AciqClipper):
mean = torch.tensor(t['mean']) mean = torch.tensor(t['mean'])
else: else:
mean = t.mean() mean = t.mean()
return torch.abs(mean) + alpha
clip_val = torch.abs(mean) + alpha
return -clip_val, clip_val
class AciqAsymmetricClipper(AciqClipper): class AciqAsymmetricClipper(AciqClipper):
...@@ -249,8 +251,8 @@ class AciqAsymmetricClipper(AciqClipper): ...@@ -249,8 +251,8 @@ class AciqAsymmetricClipper(AciqClipper):
else: else:
alpha = AciqClipper.get_alpha_gauss(t, across_dim, self.num_bits, half_range=half_range) alpha = AciqClipper.get_alpha_gauss(t, across_dim, self.num_bits, half_range=half_range)
min_val = torch.max(min_val, mean - alpha) min_val = torch.max(min_val, mean - alpha)
delta = alpha if half_range else 2 * alpha
return min_val, min_val + 2 * alpha return min_val, min_val + delta
def get_quantized_range(num_bits, signed=True): def get_quantized_range(num_bits, signed=True):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment