From 51880a226236684ed2518c4b24e03ea411d3af35 Mon Sep 17 00:00:00 2001
From: Guy Jacob <guy.jacob@intel.com>
Date: Sun, 9 Dec 2018 18:05:49 +0200
Subject: [PATCH] Bugfix in EMA calculation in FakeLinearQuantization

---
 distiller/quantization/range_linear.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index c484b9d..4da9e04 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -359,7 +359,7 @@ class PostTrainLinearQuantizer(Quantizer):
 def update_ema(biased_ema, value, decay, step):
     biased_ema = biased_ema * decay + (1 - decay) * value
     unbiased_ema = biased_ema / (1 - decay ** step)  # Bias correction
-    return unbiased_ema
+    return biased_ema, unbiased_ema
 
 
 def inputs_quantize_wrapped_forward(self, input):
@@ -394,8 +394,10 @@ class FakeLinearQuantization(nn.Module):
         with torch.no_grad():
             current_min, current_max = get_tensor_min_max(input)
             self.iter_count = self.iter_count + 1
-            self.tracked_min = update_ema(self.tracked_min_biased, current_min, self.ema_decay, self.iter_count)
-            self.tracked_max = update_ema(self.tracked_max_biased, current_max, self.ema_decay, self.iter_count)
+            self.tracked_min_biased, self.tracked_min = update_ema(self.tracked_min_biased,
+                                                                   current_min, self.ema_decay, self.iter_count)
+            self.tracked_max_biased, self.tracked_max = update_ema(self.tracked_max_biased,
+                                                                   current_max, self.ema_decay, self.iter_count)
 
         if self.mode == LinearQuantMode.SYMMETRIC:
             max_abs = max(abs(self.tracked_min), abs(self.tracked_max))
--
GitLab
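
Note on the fix: before this patch, update_ema returned only the bias-corrected value, and the caller stored it into tracked_min/tracked_max while tracked_min_biased/tracked_max_biased were never updated, so the running average never actually accumulated. Below is a minimal standalone sketch of the corrected behavior (plain Python, independent of the patched module; the decay value and the toy activation statistics are made up for illustration):

    def update_ema(biased_ema, value, decay, step):
        # Accumulate the biased EMA and also return the bias-corrected value,
        # mirroring the return signature introduced by the patch.
        biased_ema = biased_ema * decay + (1 - decay) * value
        unbiased_ema = biased_ema / (1 - decay ** step)  # Bias correction
        return biased_ema, unbiased_ema

    # The caller persists the biased accumulator between steps (as
    # tracked_max_biased does after the fix) and uses the unbiased value
    # for the quantization range.
    tracked_max_biased, decay = 0.0, 0.999
    for step, current_max in enumerate([2.0, 2.1, 1.9, 2.05], start=1):  # toy stats
        tracked_max_biased, tracked_max = update_ema(tracked_max_biased, current_max,
                                                     decay, step)
        print(step, round(tracked_max, 4))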