diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index 4da9e04287363d6fd67eea61ed07dc375e5af687..4eaafd2b4cfe26fe62bc0bfa3be080f48419ef00 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -391,23 +391,38 @@ class FakeLinearQuantization(nn.Module):
         self.register_buffer('zero_point', torch.zeros(1))
 
     def forward(self, input):
-        with torch.no_grad():
-            current_min, current_max = get_tensor_min_max(input)
-            self.iter_count = self.iter_count + 1
-            self.tracked_min_biased, self.tracked_min = update_ema(self.tracked_min_biased,
-                                                                   current_min, self.ema_decay, self.iter_count)
-            self.tracked_max_biased, self.tracked_max = update_ema(self.tracked_max_biased,
-                                                                   current_max, self.ema_decay, self.iter_count)
+        # We update the tracked stats only in training
+        #
+        # Due to the way DataParallel works, we perform all updates in-place so the "main" device retains
+        # its updates. (see https://pytorch.org/docs/stable/nn.html#dataparallel)
+        # However, as it is now, the in-place update of iter_count causes an error when doing
+        # back-prop with multiple GPUs, claiming a variable required for gradient calculation has been modified
+        # in-place. Not clear why, since it's not used in any calculations that keep a gradient.
+        # It works fine with a single GPU. TODO: Debug...
+        if self.training:
+            with torch.no_grad():
+                current_min, current_max = get_tensor_min_max(input)
+            self.iter_count += 1
+            self.tracked_min_biased.data, self.tracked_min.data = update_ema(self.tracked_min_biased.data,
+                                                                             current_min, self.ema_decay,
+                                                                             self.iter_count)
+            self.tracked_max_biased.data, self.tracked_max.data = update_ema(self.tracked_max_biased.data,
+                                                                             current_max, self.ema_decay,
+                                                                             self.iter_count)
 
         if self.mode == LinearQuantMode.SYMMETRIC:
             max_abs = max(abs(self.tracked_min), abs(self.tracked_max))
             actual_min, actual_max = -max_abs, max_abs
-            self.scale, self.zero_point = symmetric_linear_quantization_params(self.num_bits, max_abs)
+            if self.training:
+                self.scale.data, self.zero_point.data = symmetric_linear_quantization_params(self.num_bits, max_abs)
         else:
             actual_min, actual_max = self.tracked_min, self.tracked_max
             signed = self.mode == LinearQuantMode.ASYMMETRIC_SIGNED
-            self.scale, self.zero_point = asymmetric_linear_quantization_params(self.num_bits, self.tracked_min,
-                                                                                self.tracked_max, signed=signed)
+            if self.training:
+                self.scale.data, self.zero_point.data = asymmetric_linear_quantization_params(self.num_bits,
+                                                                                              self.tracked_min,
+                                                                                              self.tracked_max,
+                                                                                              signed=signed)
 
         input = clamp(input, actual_min.item(), actual_max.item(), False)
         input = LinearQuantizeSTE.apply(input, self.scale, self.zero_point, self.dequantize, False)
@@ -419,6 +434,19 @@ class FakeLinearQuantization(nn.Module):
         return 'mode={0}, num_bits={1}, ema_decay={2:.4f})'.format(mode_str, self.num_bits, self.ema_decay)
 
 
+class FakeQuantizationWrapper(nn.Module):
+    def __init__(self, wrapped_module, num_bits, quant_mode, ema_decay):
+        super(FakeQuantizationWrapper, self).__init__()
+        self.wrapped_module = wrapped_module
+        self.fake_q = FakeLinearQuantization(num_bits, quant_mode, ema_decay, dequantize=True,
+                                             inplace=getattr(wrapped_module, 'inplace', False))
+
+    def forward(self, *input):
+        res = self.wrapped_module(*input)
+        res = self.fake_q(res)
+        return res
+
+
 class QuantAwareTrainRangeLinearQuantizer(Quantizer):
     def __init__(self, model, optimizer=None, bits_activations=32, bits_weights=32, bits_overrides=OrderedDict(),
                  quantize_bias=True, mode=LinearQuantMode.SYMMETRIC, ema_decay=0.999, per_channel_wts=False,
@@ -430,6 +458,10 @@ class QuantAwareTrainRangeLinearQuantizer(Quantizer):
                                                                   quantize_bias=quantize_bias,
                                                                   train_with_fp_copy=True)
 
+        if isinstance(model, nn.DataParallel) and len(model.device_ids) > 1:
+            raise RuntimeError('QuantAwareTrainRangeLinearQuantizer currently does not support running with '
+                               'multiple GPUs')
+
         mode = verify_mode(mode)
         self.model.quantizer_metadata['params']['mode'] = str(mode).split('.')[1]
 
@@ -458,16 +490,16 @@ class QuantAwareTrainRangeLinearQuantizer(Quantizer):
             out = LinearQuantizeSTE.apply(param_fp, scale, zero_point, True, False)
             return out
 
-        def relu_replace_fn(module, name, qbits_map):
+        def activation_replace_fn(module, name, qbits_map):
             bits_acts = qbits_map[name].acts
             if bits_acts is None:
                 return module
-            return nn.Sequential(module, FakeLinearQuantization(bits_acts, mode, ema_decay, dequantize=True,
-                                                                inplace=module.inplace))
+            return FakeQuantizationWrapper(module, bits_acts, mode, ema_decay)
 
         self.param_quantization_fn = linear_quantize_param
 
-        self.replacement_factory[nn.ReLU] = relu_replace_fn
+        self.activation_replace_fn = activation_replace_fn
+        self.replacement_factory[nn.ReLU] = self.activation_replace_fn
 
     def _prepare_model_impl(self):
         super(QuantAwareTrainRangeLinearQuantizer, self)._prepare_model_impl()
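
For reference, a minimal usage sketch (not part of the patch) of the FakeQuantizationWrapper introduced above: it wraps an activation module, and the enclosed FakeLinearQuantization updates its EMA-tracked min/max and the derived scale/zero_point in-place only while the module is in training mode, then reuses the frozen parameters in eval mode. The import path follows the file touched by this diff; the 8-bit/SYMMETRIC settings are arbitrary example values.

```python
import torch
import torch.nn as nn

# Both names are defined in distiller/quantization/range_linear.py (the file modified by this patch).
from distiller.quantization.range_linear import FakeQuantizationWrapper, LinearQuantMode

# Wrap a ReLU so its output is fake-quantized to 8 bits (example values).
wrapped = FakeQuantizationWrapper(nn.ReLU(inplace=True), num_bits=8,
                                  quant_mode=LinearQuantMode.SYMMETRIC, ema_decay=0.999)

wrapped.train()
out = wrapped(torch.randn(4, 16))  # tracked min/max, scale and zero_point are updated in-place

wrapped.eval()
out = wrapped(torch.randn(4, 16))  # stats are frozen; the stored scale and zero_point are reused
```

Note that, per the guard added to QuantAwareTrainRangeLinearQuantizer.__init__, a model wrapped in nn.DataParallel over more than one GPU is rejected with a RuntimeError until the in-place iter_count issue described in the forward() comment is debugged.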