diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py index c484b9d66fbe71883b2002a3018226115042bd1c..4da9e04287363d6fd67eea61ed07dc375e5af687 100644 --- a/distiller/quantization/range_linear.py +++ b/distiller/quantization/range_linear.py @@ -359,7 +359,7 @@ class PostTrainLinearQuantizer(Quantizer): def update_ema(biased_ema, value, decay, step): biased_ema = biased_ema * decay + (1 - decay) * value unbiased_ema = biased_ema / (1 - decay ** step) # Bias correction - return unbiased_ema + return biased_ema, unbiased_ema def inputs_quantize_wrapped_forward(self, input): @@ -394,8 +394,10 @@ class FakeLinearQuantization(nn.Module): with torch.no_grad(): current_min, current_max = get_tensor_min_max(input) self.iter_count = self.iter_count + 1 - self.tracked_min = update_ema(self.tracked_min_biased, current_min, self.ema_decay, self.iter_count) - self.tracked_max = update_ema(self.tracked_max_biased, current_max, self.ema_decay, self.iter_count) + self.tracked_min_biased, self.tracked_min = update_ema(self.tracked_min_biased, + current_min, self.ema_decay, self.iter_count) + self.tracked_max_biased, self.tracked_max = update_ema(self.tracked_max_biased, + current_max, self.ema_decay, self.iter_count) if self.mode == LinearQuantMode.SYMMETRIC: max_abs = max(abs(self.tracked_min), abs(self.tracked_max))