diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index 7dc79755573e3e3edec4952a1e477f7906928ac5..6a2430491e2e24a4eedc1c5434baff0c0ef4fade 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -259,6 +259,9 @@ class RangeLinearQuantWrapper(nn.Module):
     def forward(self, *inputs):
         if self.training:
             raise RuntimeError(self.__class__.__name__ + " can only be used in eval mode")
+        device = inputs[0].device
+        for buffer_name, buffer in self._buffers.items():
+            setattr(self, buffer_name, buffer.to(device))
 
         in_scales, in_zero_points = self.get_inputs_quantization_params(*inputs)
 
@@ -409,7 +412,8 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
         device = self.w_scale.device
 
         if self.preset_act_stats:
-            self.register_buffer('accum_scale', self.in_0_scale.to(device) * self.w_scale)
+            self.in_0_scale = self.in_0_scale.to(device)
+            self.register_buffer('accum_scale', self.in_0_scale * self.w_scale)
             if self.per_channel_wts:
                 self.accum_scale = self.accum_scale.squeeze(dim=-1)
         else: