diff --git a/distiller/quantization/q_utils.py b/distiller/quantization/q_utils.py index 91d665ad3086f89c7ddba8a9158af1315d9b7518..e1bad3e91252051811fc0be76996458e0587256e 100644 --- a/distiller/quantization/q_utils.py +++ b/distiller/quantization/q_utils.py @@ -230,7 +230,9 @@ class AciqSymmetricClipper(AciqClipper): mean = torch.tensor(t['mean']) else: mean = t.mean() - return torch.abs(mean) + alpha + + clip_val = torch.abs(mean) + alpha + return -clip_val, clip_val class AciqAsymmetricClipper(AciqClipper): @@ -249,8 +251,8 @@ class AciqAsymmetricClipper(AciqClipper): else: alpha = AciqClipper.get_alpha_gauss(t, across_dim, self.num_bits, half_range=half_range) min_val = torch.max(min_val, mean - alpha) - - return min_val, min_val + 2 * alpha + delta = alpha if half_range else 2 * alpha + return min_val, min_val + delta def get_quantized_range(num_bits, signed=True):