diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py index 7dc79755573e3e3edec4952a1e477f7906928ac5..6a2430491e2e24a4eedc1c5434baff0c0ef4fade 100644 --- a/distiller/quantization/range_linear.py +++ b/distiller/quantization/range_linear.py @@ -259,6 +259,9 @@ class RangeLinearQuantWrapper(nn.Module): def forward(self, *inputs): if self.training: raise RuntimeError(self.__class__.__name__ + " can only be used in eval mode") + device = inputs[0].device + for buffer_name, buffer in self._buffers.items(): + setattr(self, buffer_name, buffer.to(device)) in_scales, in_zero_points = self.get_inputs_quantization_params(*inputs) @@ -409,7 +412,8 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper): device = self.w_scale.device if self.preset_act_stats: - self.register_buffer('accum_scale', self.in_0_scale.to(device) * self.w_scale) + self.in_0_scale = self.in_0_scale.to(device) + self.register_buffer('accum_scale', self.in_0_scale * self.w_scale) if self.per_channel_wts: self.accum_scale = self.accum_scale.squeeze(dim=-1) else: