Commit 107e4825 authored by Guy Jacob

Range linear Quant-aware training: Require single GPU, and more

* Current implementation doesn't play nicely with DataParallel, so allow
  only a single GPU for now
* Bug-fix: update the tracked ranges only when training (not in eval);
  a minimal sketch of the update rule follows below
* Refactor the activation replacement function - use an explicit wrapper
  module instead of nn.Sequential
parent c98df541
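
The tracked ranges referred to in the second bullet are exponential moving averages of the per-batch min/max, maintained through Distiller's update_ema helper and an iteration counter. update_ema itself is not part of this diff; assuming it implements the usual bias-corrected EMA (the iter_count argument suggests Adam-style bias correction), a minimal self-contained sketch of that update rule looks like this (update_ema_sketch and the toy loop are illustrative only, not Distiller code):

import torch

def update_ema_sketch(biased_ema, value, decay, step):
    # Assumed semantics of update_ema (not shown in this diff): standard EMA plus
    # bias correction for the zero-initialized accumulator, as in Adam.
    biased_ema = biased_ema * decay + (1 - decay) * value
    unbiased_ema = biased_ema / (1 - decay ** step)
    return biased_ema, unbiased_ema

# Toy demo: track the running max of a stream of activation tensors.
decay, step = 0.999, 0
tracked_max_biased = torch.zeros(1)
for _ in range(5):
    batch = torch.randn(32, 16)
    step += 1
    tracked_max_biased, tracked_max = update_ema_sketch(tracked_max_biased, batch.max(), decay, step)
print(tracked_max)  # bias-corrected running estimate of the activation max
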
@@ -391,23 +391,38 @@ class FakeLinearQuantization(nn.Module):
        self.register_buffer('zero_point', torch.zeros(1))

    def forward(self, input):
        # We update the tracked stats only in training
        #
        # Due to the way DataParallel works, we perform all updates in-place so the "main" device retains
        # its updates. (see https://pytorch.org/docs/stable/nn.html#dataparallel)
        # However, as it is now, the in-place update of iter_count causes an error when doing
        # back-prop with multiple GPUs, claiming a variable required for gradient calculation has been modified
        # in-place. Not clear why, since it's not used in any calculations that keep a gradient.
        # It works fine with a single GPU. TODO: Debug...
        if self.training:
            with torch.no_grad():
                current_min, current_max = get_tensor_min_max(input)
            self.iter_count += 1
            self.tracked_min_biased.data, self.tracked_min.data = update_ema(self.tracked_min_biased.data,
                                                                             current_min, self.ema_decay,
                                                                             self.iter_count)
            self.tracked_max_biased.data, self.tracked_max.data = update_ema(self.tracked_max_biased.data,
                                                                             current_max, self.ema_decay,
                                                                             self.iter_count)

        if self.mode == LinearQuantMode.SYMMETRIC:
            max_abs = max(abs(self.tracked_min), abs(self.tracked_max))
            actual_min, actual_max = -max_abs, max_abs
            if self.training:
                self.scale.data, self.zero_point.data = symmetric_linear_quantization_params(self.num_bits, max_abs)
        else:
            actual_min, actual_max = self.tracked_min, self.tracked_max
            signed = self.mode == LinearQuantMode.ASYMMETRIC_SIGNED
            if self.training:
                self.scale.data, self.zero_point.data = asymmetric_linear_quantization_params(self.num_bits,
                                                                                               self.tracked_min,
                                                                                               self.tracked_max,
                                                                                               signed=signed)

        input = clamp(input, actual_min.item(), actual_max.item(), False)
        input = LinearQuantizeSTE.apply(input, self.scale, self.zero_point, self.dequantize, False)
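
The comment block and the if self.training guard above are the core of the bug-fix: tracked statistics and the derived scale/zero_point now move only in training mode, and every update goes through .data so it happens in-place. The following stand-in module (not Distiller's class, just the same gating pattern in isolation) shows the intended train/eval behaviour:

import torch
import torch.nn as nn

class RangeTrackerSketch(nn.Module):
    """Illustrative stand-in: tracked stats change only while in training mode."""
    def __init__(self):
        super(RangeTrackerSketch, self).__init__()
        self.register_buffer('tracked_max', torch.zeros(1))

    def forward(self, input):
        if self.training:  # the bug-fix: no range updates during evaluation
            with torch.no_grad():
                current_max = input.abs().max()
            # update through .data, mirroring the in-place pattern used in the diff
            self.tracked_max.data.copy_(torch.max(self.tracked_max, current_max))
        return input

m = RangeTrackerSketch()
m.train()
m(torch.randn(8, 4))
saved = m.tracked_max.clone()
m.eval()
m(torch.randn(8, 4) * 100)                # large activations, but we are in eval mode
assert torch.equal(m.tracked_max, saved)  # tracked range did not move
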
@@ -419,6 +434,19 @@ class FakeLinearQuantization(nn.Module):
        return 'mode={0}, num_bits={1}, ema_decay={2:.4f})'.format(mode_str, self.num_bits, self.ema_decay)


class FakeQuantizationWrapper(nn.Module):
    def __init__(self, wrapped_module, num_bits, quant_mode, ema_decay):
        super(FakeQuantizationWrapper, self).__init__()
        self.wrapped_module = wrapped_module
        self.fake_q = FakeLinearQuantization(num_bits, quant_mode, ema_decay, dequantize=True,
                                             inplace=getattr(wrapped_module, 'inplace', False))

    def forward(self, *input):
        res = self.wrapped_module(*input)
        res = self.fake_q(res)
        return res


class QuantAwareTrainRangeLinearQuantizer(Quantizer):
    def __init__(self, model, optimizer=None, bits_activations=32, bits_weights=32, bits_overrides=OrderedDict(),
                 quantize_bias=True, mode=LinearQuantMode.SYMMETRIC, ema_decay=0.999, per_channel_wts=False,
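
FakeQuantizationWrapper makes the replaced activation an explicit module with two named children (wrapped_module and fake_q) instead of an anonymous nn.Sequential, and it reads the inplace flag defensively via getattr, so it is no longer tied to ReLU. A minimal runnable sketch of the same structure, with an identity stand-in for FakeLinearQuantization (the real class needs the rest of Distiller):

import torch
import torch.nn as nn

class FakeQuantSketch(nn.Module):
    """Identity stand-in for FakeLinearQuantization, only to keep the example self-contained."""
    def forward(self, input):
        return input

class FakeQuantizationWrapperSketch(nn.Module):
    # Same shape as FakeQuantizationWrapper in the diff: run the wrapped op, then fake-quantize its output.
    def __init__(self, wrapped_module):
        super(FakeQuantizationWrapperSketch, self).__init__()
        self.wrapped_module = wrapped_module
        self.fake_q = FakeQuantSketch()

    def forward(self, *input):
        res = self.wrapped_module(*input)
        res = self.fake_q(res)
        return res

wrapped = FakeQuantizationWrapperSketch(nn.ReLU(inplace=True))
print(wrapped)                    # named submodules: wrapped_module, fake_q
out = wrapped(torch.randn(2, 3))

Compared with the previous nn.Sequential(module, FakeLinearQuantization(...)), the explicit wrapper gives both children stable attribute names and a single forward that can pass through arbitrary *input.
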
@@ -430,6 +458,10 @@ class QuantAwareTrainRangeLinearQuantizer(Quantizer):
                                                                  quantize_bias=quantize_bias,
                                                                  train_with_fp_copy=True)

        if isinstance(model, nn.DataParallel) and len(model.device_ids) > 1:
            raise RuntimeError('QuantAwareTrainRangeLinearQuantizer currently does not support running with '
                               'multiple GPUs')

        mode = verify_mode(mode)
        self.model.quantizer_metadata['params']['mode'] = str(mode).split('.')[1]
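
The guard above only raises; it is up to the calling code to decide what to do with a model that is already wrapped in nn.DataParallel across several GPUs. One option, sketched below with a hypothetical helper (not a Distiller API), is to unwrap back to the underlying single-device module before constructing the quantizer:

import torch.nn as nn

def unwrap_for_quant_aware_training(model):
    # Hypothetical helper: range-linear quant-aware training currently supports a single
    # GPU only, so strip a multi-GPU DataParallel wrapper and return the inner module.
    if isinstance(model, nn.DataParallel) and len(model.device_ids) > 1:
        return model.module
    return model
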
@@ -458,16 +490,16 @@ class QuantAwareTrainRangeLinearQuantizer(Quantizer):
            out = LinearQuantizeSTE.apply(param_fp, scale, zero_point, True, False)
            return out

        def activation_replace_fn(module, name, qbits_map):
            bits_acts = qbits_map[name].acts
            if bits_acts is None:
                return module
            return FakeQuantizationWrapper(module, bits_acts, mode, ema_decay)

        self.param_quantization_fn = linear_quantize_param

        self.activation_replace_fn = activation_replace_fn
        self.replacement_factory[nn.ReLU] = self.activation_replace_fn

    def _prepare_model_impl(self):
        super(QuantAwareTrainRangeLinearQuantizer, self)._prepare_model_impl()
...
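
Because activation_replace_fn no longer touches ReLU-specific attributes (the inplace flag is read inside the wrapper with getattr), the same replacement function could in principle be registered for additional activation types. A hypothetical extension, not part of this commit, assuming quantizer is a constructed QuantAwareTrainRangeLinearQuantizer whose prepare_model() has not been called yet:

import torch.nn as nn

# Hypothetical: reuse the commit's replacement function for another activation type.
quantizer.replacement_factory[nn.ReLU6] = quantizer.activation_replace_fn
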