diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index c4e797fd384b6a94d2ceb1f512aa6837c4e2c580..234bfe8f5587ebd5259908db538300c66233b215 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -359,7 +359,7 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
         self.has_bias = hasattr(wrapped_module, 'bias') and wrapped_module.bias is not None
         if self.has_bias:
             if self.preset_act_stats:
-                linear_quantize_clamp(wrapped_module.bias.data, self.accum_scale, 0,
+                linear_quantize_clamp(wrapped_module.bias.data, self.accum_scale.squeeze(), 0,
                                       self.accum_min_q_val, self.accum_max_q_val, inplace=True)
             else:
                 b_scale, b_zero_point = _get_quant_params_from_tensor(wrapped_module.bias, num_bits_params, self.mode)
@@ -387,7 +387,7 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
         if self.has_bias:
             # Re-quantize bias to match x * w scale: b_q' = (in_scale * w_scale / b_scale) * (b_q + b_zero_point)
             self.wrapped_module.bias.data = linear_quantize_clamp(self.base_b_q + self.b_zero_point,
-                                                                  self.accum_scale / self.b_scale, 0,
+                                                                  self.accum_scale.squeeze() / self.b_scale, 0,
                                                                   self.accum_min_q_val, self.accum_max_q_val)
 
         # Note the main terms within the summation is:
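
The patch applies `.squeeze()` to `self.accum_scale` wherever the accumulated scale is used to quantize the 1-D bias. A likely motivation (an assumption, not stated in the diff) is that with per-channel weight quantization `accum_scale` carries trailing singleton dimensions such as `(out_channels, 1, 1, 1)`, which broadcast incorrectly against a bias of shape `(out_channels,)`. The sketch below illustrates the shape mismatch; the shapes are hypothetical and chosen only for illustration.

```python
import torch

out_channels = 4
bias = torch.randn(out_channels)                  # bias is 1-D: shape (4,)
accum_scale = torch.rand(out_channels, 1, 1, 1)   # per-channel scale with trailing singleton dims

# Broadcasting a (4,) bias against a (4, 1, 1, 1) scale yields a (4, 1, 1, 4) tensor,
# which is not a valid per-channel quantized bias.
print((bias * accum_scale).shape)            # torch.Size([4, 1, 1, 4])

# Squeezing the scale down to (4,) keeps the product element-wise and per-channel.
print((bias * accum_scale.squeeze()).shape)  # torch.Size([4])
```

After the squeeze, the scale is 1-D and the re-quantized bias retains its original per-channel shape.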