From e82d938077f9c5abc81b79cfadf245f44fabe0ba Mon Sep 17 00:00:00 2001
From: Guy Jacob <guy.jacob@intel.com>
Date: Mon, 6 Jan 2020 13:38:27 +0200
Subject: [PATCH] Post-train quant: Refactor inputs quantization (#454)

* Fake quant wrapper now also works on (fake) quantized inputs
* Remove 'requires_quantized_inputs' flag
* Unrelated: Moved LinearQuantMode enum to q_utils
---
 distiller/quantization/__init__.py          |   6 +-
 distiller/quantization/ptq_greedy_search.py |   3 +-
 distiller/quantization/q_utils.py           |   6 +
 distiller/quantization/range_linear.py      | 238 ++++++++++----------
 4 files changed, 132 insertions(+), 121 deletions(-)

diff --git a/distiller/quantization/__init__.py b/distiller/quantization/__init__.py
index e24c976..553877f 100644
--- a/distiller/quantization/__init__.py
+++ b/distiller/quantization/__init__.py
@@ -16,9 +16,11 @@
 from .quantizer import Quantizer
 from .range_linear import RangeLinearQuantWrapper, RangeLinearQuantParamLayerWrapper, PostTrainLinearQuantizer, \
-    LinearQuantMode, QuantAwareTrainRangeLinearQuantizer, add_post_train_quant_args, NCFQuantAwareTrainQuantizer, \
-    RangeLinearQuantConcatWrapper, RangeLinearQuantEltwiseAddWrapper, RangeLinearQuantEltwiseMultWrapper, ClipMode
+    QuantAwareTrainRangeLinearQuantizer, add_post_train_quant_args, NCFQuantAwareTrainQuantizer, \
+    RangeLinearQuantConcatWrapper, RangeLinearQuantEltwiseAddWrapper, RangeLinearQuantEltwiseMultWrapper, ClipMode, \
+    RangeLinearEmbeddingWrapper, RangeLinearFakeQuantWrapper, RangeLinearQuantMatmulWrapper
 from .clipped_linear import LinearQuantizeSTE, ClippedLinearQuantization, WRPNQuantizer, DorefaQuantizer, PACTQuantizer
+from .q_utils import *
 
 del quantizer
 del range_linear
diff --git a/distiller/quantization/ptq_greedy_search.py b/distiller/quantization/ptq_greedy_search.py
index 5333c4e..46d3b45 100644
--- a/distiller/quantization/ptq_greedy_search.py
+++ b/distiller/quantization/ptq_greedy_search.py
@@ -18,7 +18,8 @@ Here we implement the greedy search algorithm for automatic quantization.
 """
 import torch
 import torch.nn as nn
-from distiller.quantization.range_linear import PostTrainLinearQuantizer, ClipMode, LinearQuantMode
+from distiller.quantization import LinearQuantMode
+from distiller.quantization.range_linear import PostTrainLinearQuantizer, ClipMode
 from distiller.summary_graph import SummaryGraph
 from distiller.model_transforms import fold_batch_norms
 import distiller.modules
diff --git a/distiller/quantization/q_utils.py b/distiller/quantization/q_utils.py
index e1bad3e..8145dc7 100644
--- a/distiller/quantization/q_utils.py
+++ b/distiller/quantization/q_utils.py
@@ -18,6 +18,12 @@
 from enum import Enum
 
 import torch
 
+class LinearQuantMode(Enum):
+    SYMMETRIC = 1
+    ASYMMETRIC_UNSIGNED = 2
+    ASYMMETRIC_SIGNED = 3
+
+
 def _prep_saturation_val_tensor(sat_val):
     is_scalar = not isinstance(sat_val, torch.Tensor)
     out = torch.tensor(sat_val) if is_scalar else sat_val.clone().detach()
diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index 0765e33..e567700 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -52,12 +52,6 @@ def _enum_to_str(enum_val):
     return str(enum_val).split('.')[1]
 
 
-class LinearQuantMode(Enum):
-    SYMMETRIC = 1
-    ASYMMETRIC_UNSIGNED = 2
-    ASYMMETRIC_SIGNED = 3
-
-
 class ModuleQuantMode(namedtuple('ModuleQuantMode', ['activations', 'weights'])):
     """
     Named tuple for configuring the LinearQuantMode of both weights and activations of a module
@@ -227,12 +221,21 @@ class QuantSettings(object):
 
 
 def linear_quantize_clamp_with_metadata(t, inplace=False):
-    return linear_quantize_clamp(t, *t.quant_metadata, inplace)
+    assert hasattr(t, 'quant_metadata')
+    qmd = t.quant_metadata
+    t = linear_quantize_clamp(t, *qmd, inplace)
+    if not inplace:
+        t.quant_metadata = qmd
+    return t
 
 
 def linear_dequantize_with_metadata(t, inplace=False):
+    assert hasattr(t, 'quant_metadata')
     qmd = t.quant_metadata
-    return linear_dequantize(t, qmd.scale, qmd.zero_point, inplace)
+    t = linear_dequantize(t, qmd.scale, qmd.zero_point, inplace)
+    if not inplace:
+        t.quant_metadata = qmd
+    return t
 
 
 def add_post_train_quant_args(argparser):
@@ -331,8 +334,7 @@ class RangeLinearQuantWrapper(nn.Module):
     def __init__(self, wrapped_module, num_bits_acts, num_bits_accum=32, mode=LinearQuantMode.SYMMETRIC,
                  clip_acts=ClipMode.NONE, activation_stats=None, clip_n_stds=None, clip_half_range=False,
-                 scale_approx_mult_bits=None,
-                 input_overrides=None, requires_quantized_inputs=True, inputs_quant_auto_fallback=False):
+                 scale_approx_mult_bits=None, input_overrides=None, inputs_quant_auto_fallback=False):
         super(RangeLinearQuantWrapper, self).__init__()
 
         input_overrides = input_overrides or OrderedDict()
@@ -342,7 +344,6 @@
         self.wrapped_module = wrapped_module
         self.clip_half_range = clip_half_range
         self.scale_approx_mult_bits = scale_approx_mult_bits
-        self.requires_quantized_inputs = requires_quantized_inputs
         self.inputs_quant_auto_fallback = inputs_quant_auto_fallback
 
         self.output_quant_settings = QuantSettings(num_bits_acts, mode.activations, clip_acts, clip_n_stds,
@@ -358,26 +359,25 @@
             return
 
         # Activations are quantized - setup quantization parameters
-        if self.requires_quantized_inputs:
-            self.inputs_quant_settings_overrides = OrderedDict()
-            for k, v in input_overrides.items():
-                idx = int(k)
-                if v.pop('from_output', None):
-                    quant_settings = deepcopy(self.output_quant_settings)
-                    quant_settings.clip_half_range = False
-                else:
-                    quant_settings = QuantSettings(
-                        v.pop('bits_activations', self.output_quant_settings.num_bits),
-                        verify_quant_mode(v.pop('mode', self.output_quant_settings.quant_mode)),
-                        verify_clip_mode(v.pop('clip_acts', self.output_quant_settings.clip_mode)),
-                        v.pop('clip_n_stds', self.output_quant_settings.clip_n_stds),
-                        False, False)
-                if v:
-                    # Poor man's input checking on input overrides dict
-                    raise ValueError('Input overrides dict contains unsupported keys:', list(v.keys()))
-                self.inputs_quant_settings_overrides[idx] = quant_settings
-        else:
-            self.inputs_quant_settings_overrides = None
+
+        # Set-up inputs quantization settings
+        self.inputs_quant_settings_overrides = OrderedDict()
+        for k, v in input_overrides.items():
+            idx = int(k)
+            if v.pop('from_output', None):
+                quant_settings = deepcopy(self.output_quant_settings)
+                quant_settings.clip_half_range = False
+            else:
+                quant_settings = QuantSettings(
+                    v.pop('bits_activations', self.output_quant_settings.num_bits),
+                    verify_quant_mode(v.pop('mode', self.output_quant_settings.quant_mode)),
+                    verify_clip_mode(v.pop('clip_acts', self.output_quant_settings.clip_mode)),
+                    v.pop('clip_n_stds', self.output_quant_settings.clip_n_stds),
+                    False, False)
+            if v:
+                # Poor man's input checking on input overrides dict
+                raise ValueError('Input overrides dict contains unsupported keys:', list(v.keys()))
+            self.inputs_quant_settings_overrides[idx] = quant_settings
 
         # Controls whether output is de-quantized at end of forward op. Meant as a debug / test flag only
         # (note that if False, the quantized output will be returned, but without any quantization parameters,
@@ -392,21 +392,20 @@
         if activation_stats:
             self.preset_act_stats = True
 
-            if self.requires_quantized_inputs:
-                self.inputs_quant_metadata_fallback = OrderedDict()
-                for idx, stats in activation_stats['inputs'].items():
-                    settings = self.inputs_quant_settings_overrides.get(idx, self.output_quant_settings)
-                    scale, zp = _get_quant_params_from_stats_dict(
-                        stats, settings.num_bits, settings.quant_mode, settings.clip_mode,
-                        settings.clip_n_stds, settings.clip_half_range, self.scale_approx_mult_bits
-                    )
-                    min_q_val, max_q_val = get_quantized_range(
-                        settings.num_bits, settings.quant_mode != LinearQuantMode.ASYMMETRIC_UNSIGNED)
-                    qmd = TensorQuantMetadata(scale, zp, min_q_val, max_q_val)
-                    self.inputs_quant_metadata_fallback[idx] = qmd
-            else:
-                self.inputs_quant_metadata_fallback = None
-
+            # Calculate inputs quantization parameters
+            self.inputs_quant_metadata_fallback = OrderedDict()
+            for idx, stats in activation_stats['inputs'].items():
+                settings = self.inputs_quant_settings_overrides.get(idx, self.output_quant_settings)
+                scale, zp = _get_quant_params_from_stats_dict(
+                    stats, settings.num_bits, settings.quant_mode, settings.clip_mode,
+                    settings.clip_n_stds, settings.clip_half_range, self.scale_approx_mult_bits
+                )
+                min_q_val, max_q_val = get_quantized_range(
+                    settings.num_bits, settings.quant_mode != LinearQuantMode.ASYMMETRIC_UNSIGNED)
+                qmd = TensorQuantMetadata(scale, zp, min_q_val, max_q_val)
+                self.inputs_quant_metadata_fallback[idx] = qmd
+
+            # Calculate output quantization parameters
             scale, zp = _get_quant_params_from_stats_dict(activation_stats['output'], num_bits_acts, mode.activations,
                                                           clip_acts, clip_n_stds, clip_half_range,
                                                           scale_approx_mult_bits)
@@ -436,19 +435,7 @@
         for buffer_name, buffer in self._buffers.items():
             setattr(self, buffer_name, buffer.to(device))
 
-        if self.requires_quantized_inputs:
-            self._prepare_inputs_for_quantization(inputs)
-
-            inputs_q = []
-            for input in inputs:
-                qmd = input.quant_metadata
-                input.quant_metadata = TensorQuantMetadata(qmd.scale.to(device), qmd.zero_point.to(device),
-                                                           qmd.min_q_val, qmd.max_q_val)
-                input_q = linear_quantize_clamp_with_metadata(input)
-                input_q.quant_metadata = input.quant_metadata
-                inputs_q.append(input_q)
-        else:
-            inputs_q = inputs
+        inputs_q = [self._prepare_input(idx, input) for idx, input in enumerate(inputs)]
 
         # Forward through wrapped module
         accum = self.quantized_forward(*inputs_q)
@@ -474,42 +461,50 @@
 
         return out_f
 
-    def _prepare_inputs_for_quantization(self, inputs):
-        for idx, input in enumerate(inputs):
-            if hasattr(input, 'quant_metadata'):
-                if idx in self.inputs_quant_settings_overrides:
-                    raise RuntimeError('<{}> Input {}: CONFLICT - Tensor has embedded quantization metadata AND user '
-                                       'defined input quantization settings'.format(self.distiller_name, idx))
+    def _prepare_input(self, idx, input):
+        # Default implementation - quantize the input tensor
+        # This works for all but RangeLinearFakeQuantWrapper
+        input.quant_metadata = self._get_input_quant_metadata(idx, input)
+        return linear_quantize_clamp_with_metadata(input, inplace=False)
+
+    def _get_input_quant_metadata(self, idx, input):
+        if hasattr(input, 'quant_metadata'):
+            if idx in self.inputs_quant_settings_overrides:
+                raise RuntimeError('<{}> Input {}: CONFLICT - Tensor has embedded quantization metadata AND user '
+                                   'defined input quantization settings'.format(self.distiller_name, idx))
+            qmd = input.quant_metadata
+        else:
+            # Input doesn't have embedded quantization data propagated from a previous layer
+            # Our options are:
+            #  If user set explicit settings for this input, use those
+            #    OR
+            #  If auto fallback is set, use the output quantization settings
+            if idx not in self.inputs_quant_settings_overrides and not self.inputs_quant_auto_fallback:
+                raise RuntimeError('<{}> Input {}: Expected tensor with embedded quantization metadata. Either:\n'
+                                   '1. Make sure the previous operation is quantized\n'
+                                   '2. Provide explicit input quantization settings\n'
+                                   '3. Set inputs_quant_auto_fallback'.format(self.distiller_name, idx))
+            if self.preset_act_stats:
+                qmd = self.inputs_quant_metadata_fallback[idx]
             else:
-                # Input doesn't have embedded quantization data propagated from a previous layer
-                # Our options are:
-                #  If user set explicit settings for this input, use those
-                #    OR
-                #  If auto fallback is set, use the output quantization settings
-                if idx not in self.inputs_quant_settings_overrides and not self.inputs_quant_auto_fallback:
-                    raise RuntimeError('<{}> Input {}: Expected tensor with embedded quantization metadata. Either:\n'
-                                       '1. Make sure the previous operation is quantized\n'
-                                       '2. Provide explicit input quantization settings\n'
-                                       '3. Set inputs_quant_auto_fallback'.format(self.distiller_name, idx))
-                if self.preset_act_stats:
-                    input.quant_metadata = self.inputs_quant_metadata_fallback[idx]
+                if idx in self.inputs_quant_settings_overrides:
+                    q_settings = self.inputs_quant_settings_overrides[idx]
                 else:
-                    if idx in self.inputs_quant_settings_overrides:
-                        q_settings = self.inputs_quant_settings_overrides[idx]
-                    else:
-                        # If we're here then inputs_quant_auto_fallback is set
-                        # if self.num_forwards == 0:
-                        #     msglogger.info('<{}> Input {}: No embedded quantization metadata, '
-                        #                    'falling back to output settings'.format(self.distiller_name, idx))
-                        q_settings = deepcopy(self.output_quant_settings)
-                        q_settings.clip_half_range = False
-                    scale, zp = _get_quant_params_from_tensor(input, q_settings.num_bits, q_settings.quant_mode,
-                                                              q_settings.clip_mode, q_settings.per_channel,
-                                                              q_settings.clip_n_stds, q_settings.clip_half_range,
-                                                              self.scale_approx_mult_bits)
-                    signed = q_settings.quant_mode != LinearQuantMode.ASYMMETRIC_UNSIGNED
-                    min_q_val, max_q_val = get_quantized_range(q_settings.num_bits, signed)
-                    input.quant_metadata = TensorQuantMetadata(scale, zp, min_q_val, max_q_val)
+                    # If we're here then inputs_quant_auto_fallback is set
+                    q_settings = deepcopy(self.output_quant_settings)
+                    q_settings.clip_half_range = False
+                scale, zp = _get_quant_params_from_tensor(input, q_settings.num_bits, q_settings.quant_mode,
+                                                          q_settings.clip_mode, q_settings.per_channel,
+                                                          q_settings.clip_n_stds, q_settings.clip_half_range,
+                                                          self.scale_approx_mult_bits)
+                signed = q_settings.quant_mode != LinearQuantMode.ASYMMETRIC_UNSIGNED
+                min_q_val, max_q_val = get_quantized_range(q_settings.num_bits, signed)
+                qmd = TensorQuantMetadata(scale, zp, min_q_val, max_q_val)
+
+        # Make sure scale and zp are on correct device
+        qmd = TensorQuantMetadata(qmd.scale.to(input.device), qmd.zero_point.to(input.device),
+                                  qmd.min_q_val, qmd.max_q_val)
+        return qmd
 
     def quantized_forward(self, *inputs_q):
         """
@@ -554,24 +549,21 @@
         tmpstr = 'output_quant_settings={0}'.format(self.output_quant_settings)
         tmpstr += '\naccum_quant_settings={0}'.format(self.accum_quant_settings)
-        tmpstr += '\nrequires_quantized_inputs={0}'.format(self.requires_quantized_inputs)
-        if self.requires_quantized_inputs:
-            overrides = self.inputs_quant_settings_overrides
-            tmpstr += '\n  inputs_quant_auto_fallback={}'.format(self.inputs_quant_auto_fallback)
-            tmpstr += ', forced_quant_settings_for_inputs={}'.format(
-                'None' if not overrides else list(overrides.keys()))
-            for idx, qset in overrides.items():
-                tmpstr += '\n  input_{}_settings={}'.format(idx, qset)
+        overrides = self.inputs_quant_settings_overrides
+        tmpstr += '\n  inputs_quant_auto_fallback={}'.format(self.inputs_quant_auto_fallback)
+        tmpstr += ', forced_quant_settings_for_inputs={}'.format(
+            'None' if not overrides else list(overrides.keys()))
+        for idx, qset in overrides.items():
+            tmpstr += '\n  input_{}_settings={}'.format(idx, qset)
         tmpstr += '\nscale_approx_mult_bits={}'.format(self.scale_approx_mult_bits)
         tmpstr += '\npreset_activation_stats={0}'.format(self.preset_act_stats)
         if self.preset_act_stats:
             tmpstr += '\n  output_scale={0}, output_zero_point={1}'.format(_quant_param_to_str(
                 self.output_scale), _quant_param_to_str(self.output_zero_point))
-            if self.requires_quantized_inputs:
-                for idx in self.inputs_quant_settings_overrides:
-                    qmd = self.inputs_quant_metadata_fallback[idx]
-                    tmpstr += '\n  input_#{0}_scale={1}, input_#{0}_zero_point={2}'.format(
-                        idx, _quant_param_to_str(qmd.scale), _quant_param_to_str(qmd.zero_point))
+            for idx in self.inputs_quant_settings_overrides:
+                qmd = self.inputs_quant_metadata_fallback[idx]
+                tmpstr += '\n  input_#{0}_scale={1}, input_#{0}_zero_point={2}'.format(
+                    idx, _quant_param_to_str(qmd.scale), _quant_param_to_str(qmd.zero_point))
 
         return tmpstr
 
@@ -622,7 +614,6 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
                                                          clip_acts, activation_stats, clip_n_stds, clip_half_range,
                                                          scale_approx_mult_bits,
                                                          input_overrides=input_overrides,
-                                                         requires_quantized_inputs=True,
                                                          inputs_quant_auto_fallback=inputs_quant_auto_fallback)
 
         if not isinstance(wrapped_module, (nn.Conv2d, nn.Conv3d, nn.Linear)):
@@ -807,7 +798,6 @@ class RangeLinearQuantMatmulWrapper(RangeLinearQuantWrapper):
                                                          clip_acts, activation_stats, clip_n_stds, clip_half_range,
                                                          scale_approx_mult_bits,
                                                          input_overrides=input_overrides,
-                                                         requires_quantized_inputs=True,
                                                          inputs_quant_auto_fallback=inputs_quant_auto_fallback)
 
         if not isinstance(wrapped_module, (distiller.modules.Matmul, distiller.modules.BatchMatmul)):
@@ -859,7 +849,6 @@ class RangeLinearQuantConcatWrapper(RangeLinearQuantWrapper):
                                                          clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
                                                          scale_approx_mult_bits=scale_approx_mult_bits,
                                                          input_overrides=input_overrides,
-                                                         requires_quantized_inputs=True,
                                                          inputs_quant_auto_fallback=inputs_quant_auto_fallback)
 
     def quantized_forward(self, *inputs_q):
@@ -895,7 +884,6 @@ class RangeLinearQuantEltwiseAddWrapper(RangeLinearQuantWrapper):
                                                          clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
                                                          scale_approx_mult_bits=scale_approx_mult_bits,
                                                          input_overrides=input_overrides,
-                                                         requires_quantized_inputs=True,
                                                          inputs_quant_auto_fallback=inputs_quant_auto_fallback)
 
     def quantized_forward(self, *inputs_q):
@@ -932,7 +920,6 @@ class RangeLinearQuantEltwiseMultWrapper(RangeLinearQuantWrapper):
                                                          clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
                                                          scale_approx_mult_bits=scale_approx_mult_bits,
                                                          input_overrides=input_overrides,
-                                                         requires_quantized_inputs=True,
                                                          inputs_quant_auto_fallback=inputs_quant_auto_fallback)
 
         self.accum_scale = 1
@@ -1039,18 +1026,29 @@ class RangeLinearEmbeddingWrapper(nn.Module):
 class RangeLinearFakeQuantWrapper(RangeLinearQuantWrapper):
     def __init__(self, wrapped_module, num_bits_acts, mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE,
                  activation_stats=None, clip_n_stds=None, clip_half_range=False, scale_approx_mult_bits=None,
-                 fpq_module=None):
+                 fpq_module=None, input_overrides=None, inputs_quant_auto_fallback=False):
         super(RangeLinearFakeQuantWrapper, self).__init__(wrapped_module, num_bits_acts, mode=mode,
                                                           clip_acts=clip_acts, activation_stats=activation_stats,
                                                           clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
                                                           scale_approx_mult_bits=scale_approx_mult_bits,
-                                                          requires_quantized_inputs=False)
+                                                          input_overrides=input_overrides,
+                                                          inputs_quant_auto_fallback=inputs_quant_auto_fallback)
         self.fpq_module = str(fpq_module) if fpq_module else None
         self.dtype = torch.float
         if self.fpq_module:
            self.dtype = {'16': torch.half, '32': torch.float, '64': torch.double}[self.fpq_module]
            self.wrapped_module.to(self.dtype)
 
+    def _prepare_input(self, idx, input):
+        previously_quantized = hasattr(input, 'quant_metadata')
+        input.quant_metadata = self._get_input_quant_metadata(idx, input)
+        if previously_quantized:
+            return input
+
+        # "Fresh" tensor, so need to quantize and then de-quantize (because this is the fake-quant wrapper)
+        input_q = linear_quantize_clamp_with_metadata(input, inplace=False)
+        return linear_dequantize_with_metadata(input_q, inplace=True)
+
     def quantized_forward(self, *inputs_q):
         inputs_q = distiller.convert_tensors_recursively_to(inputs_q, dtype=self.dtype)
         outputs = self.wrapped_module(*inputs_q)
@@ -1214,7 +1212,8 @@
                                                    activation_stats=activation_stats,
                                                    clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
                                                    scale_approx_mult_bits=scale_approx_mult_bits,
-                                                   fpq_module=fpq_module)
+                                                   fpq_module=fpq_module, input_overrides=input_overrides,
+                                                   inputs_quant_auto_fallback=inputs_quant_auto_fallback)
             return RangeLinearQuantParamLayerWrapper(module, qbits.acts, qbits.wts,
                                                      num_bits_accum=self.bits_accum, mode=mode, clip_acts=clip_acts,
@@ -1244,7 +1243,8 @@
                                                    activation_stats=activation_stats,
                                                    clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
                                                    scale_approx_mult_bits=scale_approx_mult_bits,
-                                                   fpq_module=fpq_module)
+                                                   fpq_module=fpq_module, input_overrides=input_overrides,
+                                                   inputs_quant_auto_fallback=inputs_quant_auto_fallback)
             try:
                 return wrapper_type(module, qbits.acts, mode=mode, clip_acts=clip_acts,
                                     activation_stats=activation_stats,
@@ -1267,8 +1267,9 @@
         def replace_fake_quant(module, name, qbits_map, fp16=fp16,
                                clip_acts=clip_acts, clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
-                               scale_approx_mult_bits=scale_approx_mult_bits, fpq_module=fpq_module, fake=True,
-                               make_identity=False):
+                               scale_approx_mult_bits=scale_approx_mult_bits, input_overrides=None,
+                               inputs_quant_auto_fallback=inputs_quant_auto_fallback,
+                               fpq_module=fpq_module, fake=True, make_identity=False):
             if isinstance(module, (nn.ReLU, nn.ReLU6)) and make_identity:
                 named_modules = OrderedDict(self.model.named_modules())
                 pred = self.adjacency_map[name].predecessors[0].name
@@ -1288,7 +1289,8 @@
                                                  activation_stats=self.model_activation_stats.get(norm_name, None),
                                                  clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
                                                  scale_approx_mult_bits=scale_approx_mult_bits,
-                                                 fpq_module=fpq_module)
+                                                 fpq_module=fpq_module, input_overrides=input_overrides,
+                                                 inputs_quant_auto_fallback=inputs_quant_auto_fallback)
 
         self.clip_acts = clip_acts
         self.clip_n_stds = clip_n_stds
-- 
GitLab
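
Reviewer note (not part of the patch): below is a minimal usage sketch of the refactored
inputs-quantization path. Only the wrapper class name, the 'input_overrides' keys parsed in
RangeLinearQuantWrapper.__init__ ('from_output', 'bits_activations', 'mode', 'clip_acts',
'clip_n_stds') and the 'inputs_quant_auto_fallback' argument are taken from this diff; the
Linear layer, the bit-width choices and the override values are hypothetical, and the
positional signature of RangeLinearQuantParamLayerWrapper (acts then weights, as in the
call sites above) is an assumption.

    import torch
    import torch.nn as nn
    from distiller.quantization import LinearQuantMode  # re-exported via 'from .q_utils import *'
    from distiller.quantization.range_linear import RangeLinearQuantParamLayerWrapper

    fc = nn.Linear(16, 4)

    # Per-input-index override: force 6-bit quantization for input 0. An input without
    # embedded quant_metadata and without an override would otherwise require
    # inputs_quant_auto_fallback=True (or raise the RuntimeError added in this patch).
    input_overrides = {'0': {'bits_activations': 6}}

    wrapper = RangeLinearQuantParamLayerWrapper(
        fc, 8, 8,  # wrapped module, num_bits_acts, num_bits_params (assumed order)
        mode=LinearQuantMode.ASYMMETRIC_UNSIGNED,
        input_overrides=input_overrides,
        inputs_quant_auto_fallback=True)

    wrapper.eval()                     # post-train quant wrappers run in eval mode only
    out = wrapper(torch.randn(2, 16))  # input 0 quantized to 6 bits, output to 8 bits

With the same two arguments now accepted by RangeLinearFakeQuantWrapper, its overridden
_prepare_input() quantizes and then de-quantizes "fresh" tensors, while tensors that already
carry quant_metadata pass through unchanged - which is what lets the fake-quant wrapper
consume (fake-)quantized inputs without the removed 'requires_quantized_inputs' flag.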