diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index 4da9e04287363d6fd67eea61ed07dc375e5af687..4eaafd2b4cfe26fe62bc0bfa3be080f48419ef00 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -391,23 +391,39 @@ class FakeLinearQuantization(nn.Module):
         self.register_buffer('zero_point', torch.zeros(1))
 
     def forward(self, input):
-        with torch.no_grad():
-            current_min, current_max = get_tensor_min_max(input)
-        self.iter_count = self.iter_count + 1
-        self.tracked_min_biased, self.tracked_min = update_ema(self.tracked_min_biased,
-                                                               current_min, self.ema_decay, self.iter_count)
-        self.tracked_max_biased, self.tracked_max = update_ema(self.tracked_max_biased,
-                                                               current_max, self.ema_decay, self.iter_count)
+        # We only update the tracked stats when in training mode
+        #
+        # Due to the way DataParallel works, we perform all updates in-place so the "main" device retains
+        # its updates. (see https://pytorch.org/docs/stable/nn.html#dataparallel)
+        # However, the in-place update of iter_count causes an error when back-propagating with multiple
+        # GPUs, claiming that a variable required for gradient calculation has been modified in-place.
+        # The cause is unclear, since iter_count is not used in any calculation that retains a gradient.
+        # It works fine with a single GPU. TODO: Debug...
+        if self.training:
+            with torch.no_grad():
+                current_min, current_max = get_tensor_min_max(input)
+            self.iter_count += 1
+            self.tracked_min_biased.data, self.tracked_min.data = update_ema(self.tracked_min_biased.data,
+                                                                             current_min, self.ema_decay,
+                                                                             self.iter_count)
+            self.tracked_max_biased.data, self.tracked_max.data = update_ema(self.tracked_max_biased.data,
+                                                                             current_max, self.ema_decay,
+                                                                             self.iter_count)
 
         if self.mode == LinearQuantMode.SYMMETRIC:
             max_abs = max(abs(self.tracked_min), abs(self.tracked_max))
             actual_min, actual_max = -max_abs, max_abs
-            self.scale, self.zero_point = symmetric_linear_quantization_params(self.num_bits, max_abs)
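+            # Scale and zero-point are only re-computed in training; evaluation reuses the last computed values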
+            if self.training:
+                self.scale.data, self.zero_point.data = symmetric_linear_quantization_params(self.num_bits, max_abs)
         else:
             actual_min, actual_max = self.tracked_min, self.tracked_max
             signed = self.mode == LinearQuantMode.ASYMMETRIC_SIGNED
-            self.scale, self.zero_point = asymmetric_linear_quantization_params(self.num_bits, self.tracked_min,
-                                                                                self.tracked_max, signed=signed)
+            if self.training:
+                self.scale.data, self.zero_point.data = asymmetric_linear_quantization_params(self.num_bits,
+                                                                                              self.tracked_min,
+                                                                                              self.tracked_max,
+                                                                                              signed=signed)
 
         input = clamp(input, actual_min.item(), actual_max.item(), False)
         input = LinearQuantizeSTE.apply(input, self.scale, self.zero_point, self.dequantize, False)
@@ -419,6 +435,21 @@
         return 'mode={0}, num_bits={1}, ema_decay={2:.4f})'.format(mode_str, self.num_bits, self.ema_decay)
 
 
+class FakeQuantizationWrapper(nn.Module):
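+    """Wraps a module and applies fake quantization (FakeLinearQuantization) to its output."""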
+    def __init__(self, wrapped_module, num_bits, quant_mode, ema_decay):
+        super(FakeQuantizationWrapper, self).__init__()
+        self.wrapped_module = wrapped_module
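+        # Propagate the wrapped module's 'inplace' attribute, if it has one, to the fake quantization op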
+        self.fake_q = FakeLinearQuantization(num_bits, quant_mode, ema_decay, dequantize=True,
+                                             inplace=getattr(wrapped_module, 'inplace', False))
+
+    def forward(self, *input):
+        res = self.wrapped_module(*input)
+        res = self.fake_q(res)
+        return res
+
+
 class QuantAwareTrainRangeLinearQuantizer(Quantizer):
     def __init__(self, model, optimizer=None, bits_activations=32, bits_weights=32, bits_overrides=OrderedDict(),
                  quantize_bias=True, mode=LinearQuantMode.SYMMETRIC, ema_decay=0.999, per_channel_wts=False,
@@ -430,6 +461,10 @@ class QuantAwareTrainRangeLinearQuantizer(Quantizer):
                                                                   quantize_bias=quantize_bias,
                                                                   train_with_fp_copy=True)
 
+        if isinstance(model, nn.DataParallel) and len(model.device_ids) > 1:
+            raise RuntimeError('QuantAwareTrainRangeLinearQuantizer currently does not support running with '
+                               'multiple GPUs')
+
         mode = verify_mode(mode)
 
         self.model.quantizer_metadata['params']['mode'] = str(mode).split('.')[1]
@@ -458,16 +493,18 @@ class QuantAwareTrainRangeLinearQuantizer(Quantizer):
             out = LinearQuantizeSTE.apply(param_fp, scale, zero_point, True, False)
             return out
 
-        def relu_replace_fn(module, name, qbits_map):
+        def activation_replace_fn(module, name, qbits_map):
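+            # Wrap the activation module so that its output is fake-quantized
+            # (modules with no bit-width configured for activations are returned unchanged)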
             bits_acts = qbits_map[name].acts
             if bits_acts is None:
                 return module
-            return nn.Sequential(module, FakeLinearQuantization(bits_acts, mode, ema_decay, dequantize=True,
-                                                                inplace=module.inplace))
+            return FakeQuantizationWrapper(module, bits_acts, mode, ema_decay)
 
         self.param_quantization_fn = linear_quantize_param
 
-        self.replacement_factory[nn.ReLU] = relu_replace_fn
+        self.activation_replace_fn = activation_replace_fn
+        self.replacement_factory[nn.ReLU] = self.activation_replace_fn
 
     def _prepare_model_impl(self):
         super(QuantAwareTrainRangeLinearQuantizer, self)._prepare_model_impl()