From 58470d9fa37b0455480ec0fa6a8b49ef2a44b8d7 Mon Sep 17 00:00:00 2001
From: Guy Jacob <guy.jacob@intel.com>
Date: Tue, 24 Sep 2019 09:45:02 +0300
Subject: [PATCH] ACIQ bug fixes

* Return both min and max clip value in the symmetric case
* Correct delta from mean in asymmetric + half_range case
---
 distiller/quantization/q_utils.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/distiller/quantization/q_utils.py b/distiller/quantization/q_utils.py
index 91d665a..e1bad3e 100644
--- a/distiller/quantization/q_utils.py
+++ b/distiller/quantization/q_utils.py
@@ -230,7 +230,9 @@ class AciqSymmetricClipper(AciqClipper):
             mean = torch.tensor(t['mean'])
         else:
             mean = t.mean()
-        return torch.abs(mean) + alpha
+
+        clip_val = torch.abs(mean) + alpha
+        return -clip_val, clip_val
 
 
 class AciqAsymmetricClipper(AciqClipper):
@@ -249,8 +251,8 @@ class AciqAsymmetricClipper(AciqClipper):
         else:
             alpha = AciqClipper.get_alpha_gauss(t, across_dim, self.num_bits, half_range=half_range)
         min_val = torch.max(min_val, mean - alpha)
-
-        return min_val, min_val + 2 * alpha
+        delta = alpha if half_range else 2 * alpha
+        return min_val, min_val + delta
 
 
 def get_quantized_range(num_bits, signed=True):
-- 
GitLab