From fac2359bbd075879aa633aa6cfd57981bee1e250 Mon Sep 17 00:00:00 2001
From: Lev Zlotnik <46742999+levzlotnik@users.noreply.github.com>
Date: Wed, 1 May 2019 18:14:05 +0300
Subject: [PATCH] Post-train quant: Ensure quant params are located on correct device (#241)

---
 distiller/quantization/range_linear.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index 7dc7975..6a24304 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -259,6 +259,9 @@ class RangeLinearQuantWrapper(nn.Module):
     def forward(self, *inputs):
         if self.training:
             raise RuntimeError(self.__class__.__name__ + " can only be used in eval mode")
+        device = inputs[0].device
+        for buffer_name, buffer in self._buffers.items():
+            setattr(self, buffer_name, buffer.to(device))
 
         in_scales, in_zero_points = self.get_inputs_quantization_params(*inputs)
 
@@ -409,7 +412,8 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
         device = self.w_scale.device
 
         if self.preset_act_stats:
-            self.register_buffer('accum_scale', self.in_0_scale.to(device) * self.w_scale)
+            self.in_0_scale = self.in_0_scale.to(device)
+            self.register_buffer('accum_scale', self.in_0_scale * self.w_scale)
             if self.per_channel_wts:
                 self.accum_scale = self.accum_scale.squeeze(dim=-1)
         else:
--
GitLab