From fac2359bbd075879aa633aa6cfd57981bee1e250 Mon Sep 17 00:00:00 2001
From: Lev Zlotnik <46742999+levzlotnik@users.noreply.github.com>
Date: Wed, 1 May 2019 18:14:05 +0300
Subject: [PATCH] Post-train quant: Ensure quant params are located on correct
 device (#241)

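Quantization parameter buffers in the range-linear wrappers could end up on a
different device than the activations they are applied to. This patch addresses
that in two places:

* RangeLinearQuantWrapper.forward() now moves every registered buffer to the
  device of the first input tensor before the quantization parameters are used.
* RangeLinearQuantParamLayerWrapper moves in_0_scale to the device of w_scale
  before multiplying the two into the 'accum_scale' buffer.

A minimal sketch of the same buffer-relocation pattern outside Distiller (the
ToyQuantWrapper module and its single 'scale' buffer are illustrative only and
not part of this patch):

    import torch
    import torch.nn as nn

    class ToyQuantWrapper(nn.Module):
        def __init__(self, wrapped):
            super().__init__()
            self.wrapped = wrapped
            # Quantization parameter kept as a buffer so it is saved and
            # loaded with the module's state_dict.
            self.register_buffer('scale', torch.tensor(0.5))

        def forward(self, x):
            # Re-home every registered buffer onto the input's device;
            # a no-op when the devices already match.
            for name, buf in self._buffers.items():
                setattr(self, name, buf.to(x.device))
            return self.wrapped(x * self.scale)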
---
 distiller/quantization/range_linear.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index 7dc7975..6a24304 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -259,6 +259,9 @@ class RangeLinearQuantWrapper(nn.Module):
     def forward(self, *inputs):
         if self.training:
             raise RuntimeError(self.__class__.__name__ + " can only be used in eval mode")
+        device = inputs[0].device
+        for buffer_name, buffer in self._buffers.items():
+            setattr(self, buffer_name, buffer.to(device))
 
         in_scales, in_zero_points = self.get_inputs_quantization_params(*inputs)
 
@@ -409,7 +412,8 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
         device = self.w_scale.device
 
         if self.preset_act_stats:
-            self.register_buffer('accum_scale', self.in_0_scale.to(device) * self.w_scale)
+            self.in_0_scale = self.in_0_scale.to(device)
+            self.register_buffer('accum_scale', self.in_0_scale * self.w_scale)
             if self.per_channel_wts:
                 self.accum_scale = self.accum_scale.squeeze(dim=-1)
         else:
-- 
GitLab