From 51880a226236684ed2518c4b24e03ea411d3af35 Mon Sep 17 00:00:00 2001
From: Guy Jacob <guy.jacob@intel.com>
Date: Sun, 9 Dec 2018 18:05:49 +0200
Subject: [PATCH] Bugfix in EMA calculation in FakeLinearQuantization

---
 distiller/quantization/range_linear.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index c484b9d..4da9e04 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -359,7 +359,7 @@ class PostTrainLinearQuantizer(Quantizer):
 def update_ema(biased_ema, value, decay, step):
     biased_ema = biased_ema * decay + (1 - decay) * value
     unbiased_ema = biased_ema / (1 - decay ** step)  # Bias correction
-    return unbiased_ema
+    return biased_ema, unbiased_ema
 
 
 def inputs_quantize_wrapped_forward(self, input):
@@ -394,8 +394,10 @@ class FakeLinearQuantization(nn.Module):
         with torch.no_grad():
             current_min, current_max = get_tensor_min_max(input)
         self.iter_count = self.iter_count + 1
-        self.tracked_min = update_ema(self.tracked_min_biased, current_min, self.ema_decay, self.iter_count)
-        self.tracked_max = update_ema(self.tracked_max_biased, current_max, self.ema_decay, self.iter_count)
+        self.tracked_min_biased, self.tracked_min = update_ema(self.tracked_min_biased,
+                                                               current_min, self.ema_decay, self.iter_count)
+        self.tracked_max_biased, self.tracked_max = update_ema(self.tracked_max_biased,
+                                                               current_max, self.ema_decay, self.iter_count)
 
         if self.mode == LinearQuantMode.SYMMETRIC:
             max_abs = max(abs(self.tracked_min), abs(self.tracked_max))
-- 
GitLab