From cdc1775f64cf081dc426871978a9ff1250206cda Mon Sep 17 00:00:00 2001 From: Guy Jacob <guy.jacob@intel.com> Date: Thu, 6 Feb 2020 16:55:35 +0200 Subject: [PATCH] Convert Distiller PTQ models to "native" PyTorch PTQ (#458) Convert Distiller PTQ models to "native" PyTorch PTQ (#458) * New API: distiller.quantization.convert_distiller_ptq_model_to_pytorch() * Can also be called from PostTrainLinearQuantizer instance: quantizer.convert_to_pytorch() * Can also trigger from command line in image classification sample * Can save/load converted modules via apputils.load/save_checkpoint * Added Jupyter notebook tutorial * Converted modules have only the absolutely necessary quant-dequant operations. For a fully quantized model, this means just quantization of model input and de-quantization of model output. If a user keeps specific internal layers in FP32, quant-dequant operations are added as needed * Can configure either 'fbgemm' or 'qnnpack' backend. For 'fbgemm' we take care of preventing overflows (aka "reduce_range" in the PyTorch API) --- distiller/apputils/checkpoint.py | 5 + distiller/apputils/image_classifier.py | 25 +- distiller/models/__init__.py | 3 +- distiller/quantization/__init__.py | 2 + .../quantization/pytorch_quant_conversion.py | 436 ++++++++++++++++ distiller/quantization/quantizer.py | 8 +- distiller/quantization/range_linear.py | 170 +++++++ .../post_train_quant/command_line.md | 14 +- .../post_train_quant_convert_pytorch.ipynb | 481 ++++++++++++++++++ tests/test_ptq_pytorch_convert.py | 126 +++++ 10 files changed, 1264 insertions(+), 6 deletions(-) create mode 100644 distiller/quantization/pytorch_quant_conversion.py create mode 100644 jupyter/post_train_quant_convert_pytorch.ipynb create mode 100644 tests/test_ptq_pytorch_convert.py diff --git a/distiller/apputils/checkpoint.py b/distiller/apputils/checkpoint.py index 95cda26..70e9de0 100755 --- a/distiller/apputils/checkpoint.py +++ b/distiller/apputils/checkpoint.py @@ -29,6 +29,7 @@ from tabulate import tabulate import torch import distiller from distiller.utils import normalize_module_name +import distiller.quantization as quantization msglogger = logging.getLogger() @@ -224,6 +225,10 @@ def load_checkpoint(model, chkpt_file, optimizer=None, quantizer = qmd['type'](model, **qmd['params']) quantizer.prepare_model(qmd['dummy_input']) + if qmd.get('pytorch_convert', False): + msglogger.info('Converting Distiller PTQ model to PyTorch quantization API') + model = quantization.convert_distiller_ptq_model_to_pytorch(model, dummy_input=qmd['dummy_input']) + if normalize_dataparallel_keys: checkpoint['state_dict'] = {normalize_module_name(k): v for k, v in checkpoint['state_dict'].items()} anomalous_keys = model.load_state_dict(checkpoint['state_dict'], strict) diff --git a/distiller/apputils/image_classifier.py b/distiller/apputils/image_classifier.py index 3dd6339..845a96f 100755 --- a/distiller/apputils/image_classifier.py +++ b/distiller/apputils/image_classifier.py @@ -405,7 +405,7 @@ def _init_learner(args): optimizer = None msglogger.info('\nreset_optimizer flag set: Overriding resumed optimizer and resetting epoch count to 0') - if optimizer is None: + if optimizer is None and not args.evaluate: optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) msglogger.debug('Optimizer Type: %s', type(optimizer)) @@ -822,6 +822,15 @@ def earlyexit_validate_stats(args): return total_top1, total_top5, losses_exits_stats +def _convert_ptq_to_pytorch(model, args): + 
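+    # Note: the converted model runs on the CPU only (PyTorch does not offer optimized
+    # quantized execution on the GPU as of version 1.4), hence args.device is switched
+    # to 'cpu' below.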
msglogger.info('Converting Distiller PTQ model to PyTorch quantization API') + dummy_input = distiller.get_dummy_input(input_shape=model.input_shape) + model = quantization.convert_distiller_ptq_model_to_pytorch(model, dummy_input, backend=args.qe_pytorch_backend) + msglogger.debug('\nModel after conversion:\n{}'.format(model)) + args.device = 'cpu' + return model + + def evaluate_model(test_loader, model, criterion, loggers, activations_collectors=None, args=None, scheduler=None): # This sample application can be invoked to evaluate the accuracy of your model on # the test dataset. @@ -833,6 +842,9 @@ def evaluate_model(test_loader, model, criterion, loggers, activations_collector loggers = [loggers] if not args.quantize_eval: + # Handle case where a post-train quantized model was loaded, and user wants to convert it to PyTorch + if args.qe_convert_pytorch: + model = _convert_ptq_to_pytorch(model, args) return test(test_loader, model, criterion, loggers, activations_collectors, args=args) else: return quantize_and_test_model(test_loader, model, criterion, args, loggers, @@ -849,6 +861,11 @@ def quantize_and_test_model(test_loader, model, criterion, args, loggers=None, s scheduler - pass scheduler to store it in checkpoint save_flag - defaults to save both quantization statistics and checkpoint. """ + if hasattr(model, 'quantizer_metadata') and \ + model.quantizer_metadata['type'] == distiller.quantization.PostTrainLinearQuantizer: + raise RuntimeError('Trying to invoke post-training quantization on a model that has already been post-' + 'train quantized. Model was likely loaded from a checkpoint. Please run again without ' + 'passing the --quantize-eval flag') if not (args.qe_dynamic or args.qe_stats_file or args.qe_config_file): args_copy = copy.deepcopy(args) args_copy.qe_calibration = args.qe_calibration if args.qe_calibration is not None else 0.05 @@ -865,7 +882,11 @@ def quantize_and_test_model(test_loader, model, criterion, args, loggers=None, s qe_model = copy.deepcopy(model).to(args.device) quantizer = quantization.PostTrainLinearQuantizer.from_args(qe_model, args_qe) - quantizer.prepare_model(distiller.get_dummy_input(input_shape=model.input_shape)) + dummy_input = distiller.get_dummy_input(input_shape=model.input_shape) + quantizer.prepare_model(dummy_input) + + if args.qe_convert_pytorch: + qe_model = _convert_ptq_to_pytorch(qe_model, args_qe) test_res = test(test_loader, qe_model, criterion, loggers, args=args_qe) diff --git a/distiller/models/__init__.py b/distiller/models/__init__.py index 6ac57c8..c0569fc 100755 --- a/distiller/models/__init__.py +++ b/distiller/models/__init__.py @@ -146,14 +146,15 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None): model.features = torch.nn.DataParallel(model.features, device_ids=device_ids) else: model = torch.nn.DataParallel(model, device_ids=device_ids) + model.is_parallel = parallel else: device = 'cpu' + model.is_parallel = False # Cache some attributes which describe the model _set_model_input_shape_attr(model, arch, dataset, pretrained, cadene) model.arch = arch model.dataset = dataset - model.is_parallel = parallel return model.to(device) diff --git a/distiller/quantization/__init__.py b/distiller/quantization/__init__.py index 553877f..3eaeb80 100644 --- a/distiller/quantization/__init__.py +++ b/distiller/quantization/__init__.py @@ -21,6 +21,8 @@ from .range_linear import RangeLinearQuantWrapper, RangeLinearQuantParamLayerWra RangeLinearEmbeddingWrapper, RangeLinearFakeQuantWrapper, 
RangeLinearQuantMatmulWrapper from .clipped_linear import LinearQuantizeSTE, ClippedLinearQuantization, WRPNQuantizer, DorefaQuantizer, PACTQuantizer from .q_utils import * +from .pytorch_quant_conversion import convert_distiller_ptq_model_to_pytorch, distiller_qparams_to_pytorch, \ + distiller_quantized_tensor_to_pytorch del quantizer del range_linear diff --git a/distiller/quantization/pytorch_quant_conversion.py b/distiller/quantization/pytorch_quant_conversion.py new file mode 100644 index 0000000..98501b4 --- /dev/null +++ b/distiller/quantization/pytorch_quant_conversion.py @@ -0,0 +1,436 @@ +# +# Copyright (c) 2020 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import torch +import torch.nn as nn +import torch.nn.quantized as nnq +from collections import OrderedDict +import warnings +from copy import deepcopy + +import distiller +from .q_utils import LinearQuantMode, is_linear_quant_mode_symmetric + + +def need_reduce_range(distiller_quant_mode, torch_dtype): + return torch.backends.quantized.engine == 'fbgemm' and not(is_linear_quant_mode_symmetric(distiller_quant_mode) and + torch_dtype == torch.quint8) + + +def distiller_qparams_to_pytorch(scale, zp, num_bits, distiller_mode, dest_dtype, reduce_range=False): + """ + Convert quantization parameters (scale and zero-point) calculated by Distiller APIs to quantization parameters + compatible with PyTorch quantization APIs. + + By "calculated with Distiller APIs" we mean calculated using either of: + * distiller.quantization.symmetric_linear_quantization_params + * distiller.quantization.asymmetric_linear_quantization_params + + The main differences between quantization parameters as calculated by Distiller and PyTorch: + * pytorch_scale = 1 / distiller_scale + * pytorch_zero_point = -distiller_zero_point + + Args: + scale (torch.Tensor): Scale factor calcualted by Distiller + zp (torch.Tensor): Zero point calcualted by Distiller + num_bits (int): Number of bits used for quantization in Distiller + distiller_mode (distiller.quantization.LinearQuantMode): The quantization mode used in Distiller + dest_dtype (torch.dtype): PyTorch quantized dtype to convert to. Must be one of: torch.quint8, torch.qint8 + reduce_range (bool): Reduces the range of the quantized data type by 1 bit. This should mainly be used for + quantized activations with the "fbgemm" PyTorch backend - it prevents overflows. 
See: + https://github.com/pytorch/pytorch/blob/fde94e75568b527b424b108c272793e096e8e471/torch/quantization/observer.py#L294 + + Returns: + Tuple of (scale, zero_point) which are compatible with PyTorch quantization API + """ + assert dest_dtype in (torch.qint8, torch.quint8), 'Must specify one of the quantized PyTorch dtypes' + + distiller_symmetric = is_linear_quant_mode_symmetric(distiller_mode) + if distiller_symmetric and dest_dtype == torch.quint8: + reduce_range = False + + distiller_asym_signed = distiller_mode == LinearQuantMode.ASYMMETRIC_SIGNED + + if reduce_range: + assert num_bits == 8, 'reduce_range needed only when num_bits == 8' + if distiller_symmetric and dest_dtype == torch.quint8: + raise NotImplementedError('reduce_range + symmetric + quint8 not supported in PyTorch') + num_bits = 7 + if distiller_symmetric: + ratio = 63. / 127. + else: + ratio = 127. / 255. + zp_offset = 128 if distiller_asym_signed else 0 + zp = ((zp - zp_offset) * ratio + zp_offset / 2).round() + scale = scale * ratio + + scale = scale.cpu().squeeze() + zp = zp.cpu().squeeze().long() + + # Distiller scale is the reciprocal of PyTorch scale + scale_torch = 1. / scale + + n_bins_half = 2 ** (num_bits - 1) + + if distiller_symmetric: + # In Distiller symmetric is always signed with zero-point = 0, but in PyTorch it can be + # unsigned in which case we offset the zero-point to the middle of the quantized range + zp_torch = zp if dest_dtype == torch.qint8 else torch.full_like(zp, n_bins_half) + else: + pytorch_signed = dest_dtype == torch.qint8 + if distiller_asym_signed and not pytorch_signed: + zp = zp - n_bins_half + elif not distiller_asym_signed and pytorch_signed: + zp = zp + n_bins_half + # Distiller subtracts the zero-point when quantizing, PyTorch adds it. + # So we negate the zero-point calculated in Distiller + zp_torch = -zp + return scale_torch, zp_torch + + +def distiller_quantized_tensor_to_pytorch(tensor: torch.Tensor, scale, zp, num_bits, distiller_mode, dest_dtype, + per_channel=False, channel_dim=0): + """ + Convert a tensor quantized with quantization parameters calculated by Distiller to a PyTorch "native" quantized + tensor. + + We refer to quantization parameters calculated using either of: + * distiller.quantization.symmetric_linear_quantization_params + * distiller.quantization.asymmetric_linear_quantization_params + + And to tensors quantized using either of: + * distiller.quantization.linear_quantize + * distiller.quantization.linear_quantize_clamp + + Args: + tensor (torch.Tensor): The tensor quantized in Distiller + scale (torch.Tensor): Scale factor calcualted by Distiller + zp (torch.Tensor): Zero point calcualted by Distiller + num_bits (int): Number of bits used for quantization in Distiller + distiller_mode (distiller.quantization.LinearQuantMode): The quantization mode used in Distiller + dest_dtype (torch.dtype): PyTorch quantized dtype to convert to. 
Must be one of: torch.quint8, torch.qint8 + per_channel (bool): Flag in indicating if tensor was quantized per-channel + channel_dim (int): If per_channel is set, this indicates the dimension of the channel in the tensor + + Returns: + PyTorch quantized tensor (dtype one of torch.quint8 / torch.qint8 / torch.qint32) + """ + assert (tensor == tensor.int()).all(), 'Tensor does not appear to be quantized' + converted_scale, converted_zp = distiller_qparams_to_pytorch(scale, zp, num_bits, distiller_mode, dest_dtype, + reduce_range=False) + zp_diff = -converted_zp.view(zp.shape) - zp + + if dest_dtype == torch.quint8: + temp_dtype = torch.uint8 + elif dest_dtype == torch.qint8: + temp_dtype = torch.int8 + else: # dest_dtype == torch.qint32: + temp_dtype = torch.int32 + tensor = (tensor - zp_diff).to(temp_dtype) + if per_channel: + return torch._make_per_channel_quantized_tensor(tensor, converted_scale, converted_zp, channel_dim) + return torch._make_per_tensor_quantized_tensor(tensor, converted_scale, converted_zp) + + +def _ptq_convert_pass_replace_range_linear_wrappers(module): + # Hacky deferred import for now to workaround circular dependency + # TODO: Proper fix + from distiller.quantization import RangeLinearQuantWrapper + + reassign = OrderedDict() + for n, m in module.named_children(): + new_m = m + if isinstance(m, distiller.quantization.RangeLinearQuantWrapper): + new_m = m.to_pytorch_quant(need_reduce_range(m.output_quant_settings.quant_mode, torch.quint8)) + + requires_quantized_inputs = not (isinstance(new_m, nn.Sequential) and + isinstance(new_m[0], ConditionalDeQuantizeWrapper)) + + if requires_quantized_inputs: + d = OrderedDict() + for idx, qmd in m.inputs_quant_metadata_fallback.items(): + qset = m.inputs_quant_settings_overrides.get(idx, m.output_quant_settings) + scale, zp = distiller_qparams_to_pytorch(qmd.scale, qmd.zero_point, qset.num_bits, + qset.quant_mode, torch.quint8, + need_reduce_range(qset.quant_mode, torch.quint8)) + d[idx] = (scale, zp, torch.quint8) + new_m = ConditionalQuantizeWrapper(new_m, d) + elif distiller.has_children(m): + new_m = _ptq_convert_pass_replace_range_linear_wrappers(m) + elif not isinstance(m, nn.Identity): + # Module not quantized in Distiller, possibly need to de-quant input + new_m = ConditionalDeQuantizeWrapper(m) + reassign[n] = new_m + + for n, new_m in reassign.items(): + module._modules[n] = new_m + + return module + + +def patch_model_output_dequant(model): + def patched_forward(self, input): + out = self._original_forward(input) + out = self.output_dequant(out) + return out + + model.add_module('output_dequant', nnq.DeQuantize()) + model._original_forward = model.forward + # https://stackoverflow.com/questions/972/adding-a-method-to-an-existing-object-instance#comment66379065_2982 + model.forward = patched_forward.__get__(model) + + +def _ptq_convert_pass_remove_redundant_quant_dequant(model, dummy_input): + def quantize_wrapper_check_hook(module, inputs): + if not isinstance(module, ConditionalQuantize): + return + q_inputs = [] + for idx, t in enumerate(inputs): + if not isinstance(t, torch.Tensor): + continue + if t.is_quantized: + q_inputs.append(idx) + module.already_quantized = q_inputs + + def dequant_wrapper_check_hook(module, input): + if not isinstance(module, ConditionalDeQuantize): + return + module.any_quantized = False + + def check_recursively(x): + if isinstance(x, torch.Tensor) and x.is_quantized: + module.any_quantized = True + elif isinstance(x, (tuple, list)): + for item in x: + check_recursively(item) + + 
check_recursively(input) + + def cleanup(module): + reassign = OrderedDict() + for n, m in module.named_children(): + new_m = m + if isinstance(m, ConditionalQuantizeWrapper): + for idx in m.quant.already_quantized: + if str(idx) in m.quant.quantizers: + m.quant.quantizers.pop(str(idx)) + if len(m.quant.quantizers) == 0: + new_m = m.wrapped + elif isinstance(m, ConditionalDeQuantizeWrapper): + if not m.dequant.any_quantized: + new_m = m.wrapped + elif distiller.has_children(m): + cleanup(m) + reassign[n] = new_m + for n, new_m in reassign.items(): + module._modules[n] = new_m + + return module + + handles = [] + for m in model.modules(): + if isinstance(m, ConditionalQuantize): + handles.append(m.register_forward_pre_hook(quantize_wrapper_check_hook)) + elif isinstance(m, ConditionalDeQuantize): + handles.append(m.register_forward_pre_hook(dequant_wrapper_check_hook)) + out = model(dummy_input) + for h in handles: + h.remove() + + model = cleanup(model) + + if out.is_quantized: + patch_model_output_dequant(model) + + return model + + +def convert_distiller_ptq_model_to_pytorch(model, dummy_input, backend='fbgemm'): + """ + Convert a model quantized using distiller.quantization.PostTrainLinearQuantizer to model comprised solely of + native PyTorch static post-training quantization modules and operators. + + In the current implementation this conversion CANNOT be done in-place. + + Conversion is done in 2 passes: + * First pass: Replace all RangeLinearQuantWrapper modules with a quantize operation followed by the respective + native PyTorch module. Modules that weren't quantized by Distiller are wrapped with a de-quantize operation. + * Second pass: Perform dummy forward pass over the model and remove redundant de-quant --> quant sequences. + + The converted model returns a de-quantized output. If the last layer of the model is quantized, then an extra + dequantize module will be added to the model. This extra module is named 'output_dequant', and the model's + forward method is patched to execute this module after the main model. + NOTE: This assumes the model produces a single output tensor. In other cases the results are unexpected. + + NOTE: The converted model will be on the CPU, and non-parallel (that is - without nn.DataParallel modules) + + Args: + model (torch.nn.Module): The model to be converted + dummy_input (torch.nn.Tensor): A tensor in the shape expected by the model, required for the second pass + of the conversion + backend (str): The PyTorch quantization backend to use. Currently supported values: 'fbgemm', 'qnnpack' + + Returns: + The converted model + """ + # Hacky deferred import for now to workaround circular dependency + # TODO: Proper fix + from distiller.quantization import PostTrainLinearQuantizer + if not hasattr(model, 'quantizer_metadata') or model.quantizer_metadata['type'] != PostTrainLinearQuantizer: + raise ValueError('Conversion to PyTorch native quantization supported only for models quantized ' + 'using distiller.quantization.PostTrainLinearQuantizer') + + if dummy_input is None or not isinstance(dummy_input, torch.Tensor): + raise ValueError('Valid dummy input tensor required for converting PTQ model to PyTorch') + + backends = ('fbgemm', 'qnnpack') + if backend not in backends: + raise ValueError('{} is not a supported PyTorch quantization backend. Supported: {}'.format(backend, backends)) + torch.backends.quantized.engine = backend + + # TODO: Add in-place option. 
Not totally straight-forward because of the output dequantization + # Can monkey-patch instead of creating a Sequential, then it can really be in-place + + # Save quantizer metadata so we can re-attach it to the model after conversion, which enables loading the + # converted model from a checkpoint + quantizer_metadata = deepcopy(model.quantizer_metadata) + model = distiller.make_non_parallel_copy(model).cpu() + + # First pass + model = _ptq_convert_pass_replace_range_linear_wrappers(model) + + # Second pass + model = _ptq_convert_pass_remove_redundant_quant_dequant(model, dummy_input) + + # This is used when loading the model from a checkpoint, to indicate that conversion needs to be applied + quantizer_metadata['pytorch_convert'] = True + model.quantizer_metadata = quantizer_metadata + + return model + + +class QFunctionalWrapper(nn.Module): + def __init__(self): + super(QFunctionalWrapper, self).__init__() + self.qfunc = nnq.QFunctional() + + +class QFunctionalAdd(QFunctionalWrapper): + def __init__(self): + super(QFunctionalAdd, self).__init__() + + def forward(self, x, y): + return self.qfunc.add(x, y) + + +class QFunctionalAddScalar(QFunctionalWrapper): + def __init__(self): + super(QFunctionalAddScalar, self).__init__() + + def forward(self, x, y): + return self.qfunc.add_scalar(x, y) + + +class QFunctionalMul(QFunctionalWrapper): + def __init__(self): + super(QFunctionalMul, self).__init__() + + def forward(self, x, y): + return self.qfunc.mul(x, y) + + +class QFunctionalMulScalar(QFunctionalWrapper): + def __init__(self): + super(QFunctionalMulScalar, self).__init__() + + def forward(self, x, y): + return self.qfunc.mul_scalar(x, y) + + +class QFunctionalCat(QFunctionalWrapper): + def __init__(self, dim=0): + super(QFunctionalCat, self).__init__() + self.dim = dim + + def forward(self, *x): + return self.qfunc.cat(x, self.dim) + + +class QFunctionalAddRelu(QFunctionalWrapper): + def __init__(self): + super(QFunctionalAddRelu, self).__init__() + + def forward(self, x, y): + return self.qfunc.add_relu(x, y) + + +class ConditionalDeQuantize(nn.Module): + def __init__(self): + super(ConditionalDeQuantize, self).__init__() + + def forward(self, *inputs): + def dequant_recursively(x): + if isinstance(x, torch.Tensor): + return x.dequantize() if x.is_quantized else x + if isinstance(x, (tuple, list)): + return type(x)(dequant_recursively(item) for item in x) + return x + outputs = dequant_recursively(inputs) + return outputs + + +class ConditionalDeQuantizeWrapper(nn.Module): + def __init__(self, wrapped_module): + super(ConditionalDeQuantizeWrapper, self).__init__() + self.dequant = ConditionalDeQuantize() + self.wrapped = wrapped_module + + def forward(self, *inputs): + out = self.dequant(*inputs) + out = self.wrapped(*out) + return out + + +class ConditionalQuantize(nn.Module): + def __init__(self, inputs_to_qparams_map): + super(ConditionalQuantize, self).__init__() + self.quantizers = nn.ModuleDict() + for idx, qparams in inputs_to_qparams_map.items(): + self.quantizers[str(idx)] = nnq.Quantize(*qparams) + + def forward(self, *inputs): + q_inputs = [] + for idx, item in enumerate(inputs): + idx_str = str(idx) + if idx_str in self.quantizers: + assert isinstance(item, torch.Tensor), 'Trying to quantize a non-Tensor object' + if not item.is_quantized: + item = self.quantizers[idx_str](item) + q_inputs.append(item) + # return q_inputs[0] if len(q_inputs) == 1 else tuple(q_inputs) + return tuple(q_inputs) + + +class ConditionalQuantizeWrapper(nn.Module): + def __init__(self, 
wrapped_module, inputs_to_qparams_map): + super(ConditionalQuantizeWrapper, self).__init__() + self.quant = ConditionalQuantize(inputs_to_qparams_map) + self.wrapped = wrapped_module + + def forward(self, *inputs): + out = self.quant(*inputs) + out = self.wrapped(*out) + return out diff --git a/distiller/quantization/quantizer.py b/distiller/quantization/quantizer.py index 1b33530..f4bc448 100644 --- a/distiller/quantization/quantizer.py +++ b/distiller/quantization/quantizer.py @@ -190,6 +190,8 @@ class Quantizer(object): self.modules_processed = OrderedDict() self.modules_processed_args = OrderedDict() + self.prepared = False + def _add_qbits_entry(self, module_name, module_type, qbits): if module_type not in [nn.Conv2d, nn.Conv3d, nn.Linear, nn.Embedding]: # For now we support weights quantization only for Conv, FC and Embedding layers (so, for example, we don't @@ -247,7 +249,7 @@ class Quantizer(object): self.params_to_quantize.append(_ParamToQuant(module, module_name, fp_attr_name, param_name, n_bits)) param_full_name = '.'.join([module_name, param_name]) - msglogger.info( + msglogger.debug( "Parameter '{0}' will be quantized to {1} bits".format(param_full_name, n_bits)) # If an optimizer was passed, assume we need to update it @@ -262,7 +264,9 @@ class Quantizer(object): distiller.assign_layer_fq_names(self.model) - msglogger.info('Quantized model:\n\n{0}\n'.format(self.model)) + self.prepared = True + + msglogger.debug('Quantized model:\n\n{0}\n'.format(self.model)) def _pre_prepare_model(self, dummy_input): pass diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py index 5e3c7f8..ef92314 100644 --- a/distiller/quantization/range_linear.py +++ b/distiller/quantization/range_linear.py @@ -33,6 +33,11 @@ from .q_utils import * from .sim_bn_fold import SimulatedFoldedBatchNorm import distiller.modules import distiller.model_transforms as mt +from . 
import pytorch_quant_conversion as pytqc + +import torch.quantization +import torch.nn.quantized as nnq +import torch.nn.intrinsic.quantized as nniq msglogger = logging.getLogger() @@ -326,6 +331,10 @@ def add_post_train_quant_args(argparser): 'this number of bits the integer multiplier') group.add_argument('--qe-save-fp-weights', action='store_true', help='Allow weights requantization.') + group.add_argument('--qe-convert-pytorch', '--qept', action='store_true', + help='Convert the model to PyTorch native post-train quantization modules') + group.add_argument('--qe-pytorch-backend', default='fbgemm', choices=['fbgemm', 'qnnpack'], + help='When --qe-convert-pytorch is set, specifies the PyTorch quantization backend to use') stats_group = group.add_mutually_exclusive_group() stats_group.add_argument('--qe-stats-file', type=str, metavar='PATH', @@ -657,6 +666,15 @@ class RangeLinearQuantWrapper(nn.Module): """ raise NotImplementedError + def to_pytorch_quant(self, reduce_range): + assert self.output_quant_settings.num_bits == 8, \ + 'Conversion to PyTorch PTQ supported only for 8-bit quantization' + assert self.preset_act_stats, 'Conversion to PyTorch PTQ supported only for PTQ wrappers with activation stats' + return self._convert_to_pytorch_quant(reduce_range) + + def _convert_to_pytorch_quant(self, reduce_range): + raise NotImplementedError + def extra_repr(self): if self.output_quant_settings.num_bits is None: return 'output_quant_settings=Not_Quantized' @@ -952,6 +970,63 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper): requant_scale = approx_scale_as_mult_and_shift(requant_scale, self.scale_approx_mult_bits) return requant_scale, output_zero_point + def _convert_to_pytorch_quant(self, reduce_range): + wrapped = self.wrapped_module + supported = (nn.Conv2d, nn.Linear) + # Tuple of module type and flag for relu fusing + mapping = { + (nn.Linear, False): nnq.Linear, + (nn.Linear, True): nniq.LinearReLU, + (nn.Conv2d, False): nnq.Conv2d, + (nn.Conv2d, True): nniq.ConvReLU2d + } + if nn.Conv3d in torch.quantization.DEFAULT_MODULE_MAPPING: + # Conv3D supported only from PyTorch 1.4 + supported += nn.Conv3d, + mapping.update({ + (nn.Conv3d, False): nnq.Conv3d, + (nn.Conv3d, True): nniq.ConvReLU3d, + }) + assert isinstance(wrapped, supported), \ + 'Conversion to PyTorch PTQ supported only for {}'.format(','.join(supported)) + assert self.wts_quant_settings.num_bits == 8, 'Conversion to PyTorch PTQ supported only for 8-bit quantization' + + # Convert weights - required by PyTorch to be signed 8-bit (torch.qint8) + q_weight = pytqc.distiller_quantized_tensor_to_pytorch(wrapped.weight.clone().detach(), + self.w_scale, self.w_zero_point, + self.wts_quant_settings.num_bits, + self.wts_quant_settings.quant_mode, torch.qint8, + self.wts_quant_settings.per_channel, 0) + + # PyTorch PTQ modules expect the bias in FP32, we need to dequantize if necessary + # With Distiller PTQ the bias is only quantized on the first forward - we do a crude check if it has + # been quantized or not + fp_bias = wrapped.bias.clone().detach() + if self.has_bias: + bias_quantized = (fp_bias == fp_bias.int()).all() + if bias_quantized: + fp_bias = linear_dequantize(fp_bias, self.accum_scale.squeeze(), 0, True) + + pytorch_cls = mapping[(type(wrapped), self.clip_half_range)] + if isinstance(wrapped, nn.Linear): + pytorch_module = pytorch_cls(wrapped.in_features, wrapped.out_features, wrapped.bias is not None) + else: + pytorch_module = pytorch_cls(wrapped.in_channels, wrapped.out_channels, 
wrapped.kernel_size, + wrapped.stride, wrapped.padding, wrapped.dilation, wrapped.groups, + wrapped.bias is not None, wrapped.padding_mode) + + pytorch_module.set_weight_bias(q_weight, fp_bias) + + # Convert activations qparams - required by PyTorch to be unsigned 8-bit (torch.quint8) + out_scale, out_zp = pytqc.distiller_qparams_to_pytorch(self.output_scale, self.output_zero_point, + self.output_quant_settings.num_bits, + self.output_quant_settings.quant_mode, torch.quint8, + reduce_range) + pytorch_module.scale = float(out_scale) + pytorch_module.zero_point = int(out_zp) + + return pytorch_module + def extra_repr(self): tmpstr = 'weights_quant_settings={0}\n'.format(self.wts_quant_settings) tmpstr += super(RangeLinearQuantParamLayerWrapper, self).extra_repr() @@ -1026,6 +1101,19 @@ class RangeLinearQuantMatmulWrapper(RangeLinearQuantWrapper): requant_scale = approx_scale_as_mult_and_shift(requant_scale, self.scale_approx_mult_bits) return requant_scale, output_zero_point + def _convert_to_pytorch_quant(self, reduce_range): + # Convert activations qparams - required by PyTorch to be unsigned 8-bit (torch.quint8) + scale, zp = pytqc.distiller_qparams_to_pytorch(self.output_scale, self.output_zero_point, + self.output_quant_settings.num_bits, + self.output_quant_settings.quant_mode, torch.quint8, + reduce_range) + modules = [self.wrapped_module, nnq.Quantize(float(scale), int(zp), torch.quint8)] + if self.clip_half_range: + # The scale factor calculated in Distiller already considers the ReLU, so it's OK to apply the + # ReLU after quantization + modules.append(nnq.ReLU()) + return modules + class NoStatsError(NotImplementedError): pass @@ -1065,6 +1153,21 @@ class RangeLinearQuantConcatWrapper(RangeLinearQuantWrapper): # Nothing to do here, since we already re-quantized in quantized_forward prior to the actual concatenation return 1., self.output_zero_point + def _convert_to_pytorch_quant(self, reduce_range): + # Convert activations qparams - required by PyTorch to be unsigned 8-bit (torch.quint8) + scale, zp = pytqc.distiller_qparams_to_pytorch(self.output_scale, self.output_zero_point, + self.output_quant_settings.num_bits, + self.output_quant_settings.quant_mode, torch.quint8, + reduce_range) + m = pytqc.QFunctionalCat(self.wrapped_module.dim) + m.qfunc.scale = float(scale) + m.qfunc.zp = int(zp) + if self.clip_half_range: + # The scale factor calculated in Distiller already considers the ReLU, so it's OK to apply the + # ReLU after quantization + m = nn.Sequential(m, nnq.ReLU()) + return m + class RangeLinearQuantEltwiseAddWrapper(RangeLinearQuantWrapper): def __init__(self, wrapped_module, num_bits_acts, mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE, @@ -1101,6 +1204,17 @@ class RangeLinearQuantEltwiseAddWrapper(RangeLinearQuantWrapper): def get_accum_to_output_re_quantization_params(self, output_scale, output_zero_point): return 1., self.output_zero_point + def _convert_to_pytorch_quant(self, reduce_range): + # Convert activations qparams - required by PyTorch to be unsigned 8-bit (torch.quint8) + scale, zp = pytqc.distiller_qparams_to_pytorch(self.output_scale, self.output_zero_point, + self.output_quant_settings.num_bits, + self.output_quant_settings.quant_mode, torch.quint8, + reduce_range) + m = pytqc.QFunctionalAddRelu() if self.clip_half_range else pytqc.QFunctionalAdd() + m.qfunc.scale = float(scale) + m.qfunc.zp = int(zp) + return m + class RangeLinearQuantEltwiseMultWrapper(RangeLinearQuantWrapper): def __init__(self, wrapped_module, num_bits_acts, 
mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE, @@ -1142,6 +1256,21 @@ class RangeLinearQuantEltwiseMultWrapper(RangeLinearQuantWrapper): requant_scale = approx_scale_as_mult_and_shift(requant_scale, self.scale_approx_mult_bits) return requant_scale, output_zero_point + def _convert_to_pytorch_quant(self, reduce_range): + # Convert activations qparams - requirec by PyTorch to be unsigned 8-bit (torch.quint8) + scale, zp = pytqc.distiller_qparams_to_pytorch(self.output_scale, self.output_zero_point, + self.output_quant_settings.num_bits, + self.output_quant_settings.quant_mode, torch.quint8, + reduce_range) + m = pytqc.QFunctionalMul() + m.qfunc.scale = float(scale) + m.qfunc.zp = int(zp) + if self.clip_half_range: + # The scale factor calculated in Distiller already considers the ReLU, so it's OK to apply the + # ReLU after quantization + m = nn.Sequential(m, nnq.ReLU()) + return m + class FPWrapper(nn.Module): """ @@ -1342,6 +1471,34 @@ class RangeLinearFakeQuantWrapper(RangeLinearQuantWrapper): def get_accum_to_output_re_quantization_params(self, output_scale, output_zero_point): return output_scale, output_zero_point + def _convert_to_pytorch_quant(self, reduce_range): + # A few PyTorch modules support quantized inputs/outputs + supported = { + nn.ReLU: nnq.ReLU(), + nn.ReLU6: nnq.ReLU6(), + nn.AvgPool2d: self.wrapped_module, + nn.AdaptiveAvgPool2d: self.wrapped_module, + nn.MaxPool2d: self.wrapped_module + } + q_module = supported.get(type(self.wrapped_module), None) + if q_module is None: + # No PyTorch quantized module - so fake it + # Convert activations qparams - required by PyTorch to be unsigned 8-bit (torch.quint8) + scale, zp = pytqc.distiller_qparams_to_pytorch(self.output_scale, self.output_zero_point, + self.output_quant_settings.num_bits, + self.output_quant_settings.quant_mode, torch.quint8, + reduce_range) + modules = [pytqc.ConditionalDeQuantizeWrapper(self.wrapped_module), + nnq.Quantize(float(scale), int(zp), torch.quint8)] + else: + modules = [self.wrapped_module] + if self.clip_half_range: + # The scale factor calculated in Distiller already considers the ReLU, so it's OK to apply the + # ReLU after quantization + modules.append(nnq.ReLU()) + + return modules[0] if len(modules) == 1 else nn.Sequential(*modules) + def extra_repr(self): tmpstr = super(RangeLinearFakeQuantWrapper, self).extra_repr() if self.dtype: @@ -1959,6 +2116,19 @@ class PostTrainLinearQuantizer(Quantizer): buffer.data = buffer.data.to(device) self.linear_quant_params = OrderedDict(self.named_linear_quant_params()) + def convert_to_pytorch(self, dummy_input, backend='fbgemm'): + """ + Convert a model quantized using distiller.quantization.PostTrainLinearQuantizer to model comprised solely of + native PyTorch static post-training quantization modules and operators. + + This is a convenience wrapper around distiller.quantization.convert_distiller_ptq_model_to_pytorch + See that function's documentation for more details. 
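+
+        A minimal usage sketch (illustrative only; it assumes this quantizer instance and a
+        dummy input matching the model's expected input shape, mirroring the flow shown in the
+        accompanying notebook):
+
+            dummy_input = distiller.get_dummy_input(input_shape=model.input_shape)
+            quantizer.prepare_model(dummy_input)
+            pyt_model = quantizer.convert_to_pytorch(dummy_input, backend='fbgemm')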
+ """ + if not self.prepared: + raise RuntimeError("Must call prepare_model before attempting to convert to PyTorch") + + return pytqc.convert_distiller_ptq_model_to_pytorch(self.model, dummy_input, backend=backend) + ############################################################################### # Quantization-aware training diff --git a/examples/quantization/post_train_quant/command_line.md b/examples/quantization/post_train_quant/command_line.md index b168b52..ce2c948 100644 --- a/examples/quantization/post_train_quant/command_line.md +++ b/examples/quantization/post_train_quant/command_line.md @@ -24,13 +24,25 @@ Post-training quantization can either be configured straight from the command-li | `--qe-clip-acts` | `--qeca` | Set activations clipping mode. Choices: "none", "avg", "n_std", "gauss", "laplace" | "none" | | `--qe-clip-n-stds` | N/A | When qe-clip-acts is set to 'n_std', this is the number of standard deviations to use | None | | `--qe-no-clip-layers` | `--qencl` | List of layer names (space-separated) for which not to clip activations | '' | +| `--qe-no-quant-layers` | `--qenql` | List of layer names (space-separated) for which not to skip quantization | '' | | `--qe-per-channel` | `--qepc` | Enable per-channel quantization of weights (per output channel) | Off | | `--qe-scale-approx-bits` | `--qesab` | Enables scale factor approximation using integer multiply + bit shift, using this number of bits the integer multiplier | None | | `--qe-stats-file` | N/A | Use stats file for static quantization of activations. See details below | None | | `--qe-dynamic` | N/A | Perform dynamic quantization. See details below | None | | `--qe-config-file` | N/A | Path to YAML config file. See section above. (ignores all other --qe* arguments) | None | +| `--qe-convert-pytorch` | `--qept` | Convert the model to PyTorch native post-train quantization modules | Off | +| `--qe-pytorch-backend` | N/A | When --qe-convert-pytorch is set, specifies the PyTorch quantization backend to use. Choices: "fbgemm", "qnnpack" | Off | -(Note that these arguments can be added to any `argparse.ArgumentParser` by calling `distiller.quantization.add_post_train_quant_args()` and passing an existing parser) +### Notes + +1. These arguments can be added to any `argparse.ArgumentParser` by calling `distiller.quantization.add_post_train_quant_args()` and passing an existing parser. This is provided as a convenience only. If you are writing a script and adding these arguments, it is up to you to implement the actual functionality implied by them. +2. The `--qe-convert-pytorch` works in two settings: + * `--quantize-eval` is also set, in which case an FP32 model is first quantized using Distiller's post-training quantization flow, and then converted to a PyTorch native quantization model. + * `--quantize-eval` is not set, but a previously post-train quantized model is loaded via `--resume`. In this case, the loaded model is converted to PyTorch native quantization. + +### Conversion to PyTorch Built-in Quantization Model + +PyTorch released built-in support for quantization in version 1.3. Currently Distiller's quantization functionality is still completely separate from PyTorch's. 
We provide the ability to take a model which was post-train quantized with Distiller, and is comprised of `RangeLinearQuantWrapper` ## "Net-Aware" Quantization diff --git a/jupyter/post_train_quant_convert_pytorch.ipynb b/jupyter/post_train_quant_convert_pytorch.ipynb new file mode 100644 index 0000000..74e4d0a --- /dev/null +++ b/jupyter/post_train_quant_convert_pytorch.ipynb @@ -0,0 +1,481 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Convert Distiller Post-Train Quantization Models to \"Native\" PyTorch\n", + "\n", + "## Background\n", + "\n", + "As of version 1.3 PyTorch comes with built-in quantization functionality. Details are available [here](https://pytorch.org/docs/stable/quantization.html). Distiller's and PyTorch's implementations are completely unrelated. An advantage of PyTorch built-in quantization is that it offers optimized 8-bit execution on CPU and export to GLOW. PyTorch doesn't offer optimized 8-bit execution on GPU (as of version 1.4).\n", + "\n", + "At the moment we are still keeping Distiller's separate API and implementation, but we've added the capability to convert a **post-training quantization** model created in Distiller to a \"Distiller-free\" model, comprised entirely of PyTorch built-in quantized modules.\n", + "\n", + "Distiller's quantized layers are actually simulated in FP32. Hence, comparing a Distiller model running on CPU to a PyTorch built-in model, the latter will be significantly faster on CPU. However, a Distiller model on a GPU is still likely to be faster compared to a PyTorch model on CPU. So experimenting with Distiller and converting to PyTorch in the end could be useful. Milage may vary of course, depending on the actual HW setup.\n", + "\n", + "Let's see how the conversion works." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import matplotlib.pyplot as plt\n", + "import os\n", + "import math\n", + "import torchnet as tnt\n", + "from ipywidgets import widgets, interact\n", + "from copy import deepcopy\n", + "from collections import OrderedDict\n", + "\n", + "import distiller\n", + "from distiller.models import create_model\n", + "import distiller.quantization as quant\n", + "\n", + "# Load some common code and configure logging\n", + "# We do this so we can see the logging output coming from\n", + "# Distiller function calls\n", + "%run './distiller_jupyter_helpers.ipynb'\n", + "msglogger = config_notebooks_logger()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# By default, the model is moved to the GPU and parallelized (wrapped with torch.nn.DataParallel)\n", + "# If no GPU is available, a non-parallel model is created on the CPU\n", + "model = create_model(pretrained=True, dataset='imagenet', arch='resnet18', parallel=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create Data Loaders\n", + "\n", + "We create separate data loaders for GPU and CPU. 
Set `batch_size` and `num_workers` to optimal values that match your HW setup.\n", + "\n", + "(Note we reset the seed before creating each data loader, to make sure both loaders consist of the same subset of the test set)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# We use Distiller's built-in data loading functionality for ImageNet\n", + "\n", + "distiller.set_seed(0)\n", + "\n", + "subset_size = 1.0 # To save time, can set to value < 1.0\n", + "dataset = 'imagenet'\n", + "dataset_path = os.path.expanduser('/data2/datasets/imagenet')\n", + "\n", + "batch_size_gpu = 256\n", + "num_workers_gpu = 10\n", + "_, _, test_loader_gpu, _ = distiller.apputils.load_data(\n", + " dataset, dataset_path, batch_size_gpu, num_workers_gpu,\n", + " effective_test_size=subset_size, fixed_subset=True, test_only=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "distiller.set_seed(0)\n", + "batch_size_cpu = 44\n", + "num_workers_cpu = 10\n", + "_, _, test_loader_cpu, _ = distiller.apputils.load_data(\n", + " dataset, dataset_path, batch_size_cpu, num_workers_cpu,\n", + " effective_test_size=subset_size, fixed_subset=True, test_only=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define Evaluation Function" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def eval_model(data_loader, model, device, print_freq=10):\n", + " print('Evaluating model')\n", + " criterion = torch.nn.CrossEntropyLoss().to(device)\n", + " \n", + " loss = tnt.meter.AverageValueMeter()\n", + " classerr = tnt.meter.ClassErrorMeter(accuracy=True, topk=(1, 5))\n", + "\n", + " total_samples = len(data_loader.sampler)\n", + " batch_size = data_loader.batch_size\n", + " total_steps = math.ceil(total_samples / batch_size)\n", + " print('{0} samples ({1} per mini-batch)'.format(total_samples, batch_size))\n", + "\n", + " # Switch to evaluation mode\n", + " model.eval()\n", + "\n", + " for step, (inputs, target) in enumerate(data_loader):\n", + " with torch.no_grad():\n", + " inputs, target = inputs.to(device), target.to(device)\n", + " # compute output from model\n", + " output = model(inputs)\n", + "\n", + " # compute loss and measure accuracy\n", + " loss.add(criterion(output, target).item())\n", + " classerr.add(output.data, target)\n", + " \n", + " if (step + 1) % print_freq == 0:\n", + " print('[{:3d}/{:3d}] Top1: {:.3f} Top5: {:.3f} Loss: {:.3f}'.format(\n", + " step + 1, total_steps, classerr.value(1), classerr.value(5), loss.mean), flush=True)\n", + " print('----------')\n", + " print('Overall ==> Top1: {:.3f} Top5: {:.3f} Loss: {:.3f}'.format(\n", + " classerr.value(1), classerr.value(5), loss.mean), flush=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Post-Train Quantize with Distiller" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "quant_mode = {'activations': 'ASYMMETRIC_UNSIGNED', 'weights': 'SYMMETRIC'}\n", + "stats_file = \"../examples/quantization/post_train_quant/stats/resnet18_quant_stats.yaml\"\n", + "dummy_input = distiller.get_dummy_input(input_shape=model.input_shape)\n", + "\n", + "quantizer = quant.PostTrainLinearQuantizer(\n", + " deepcopy(model), bits_activations=8, bits_parameters=8, mode=quant_mode,\n", + " model_activation_stats=stats_file, overrides=None\n", + ")\n", + 
"quantizer.prepare_model(dummy_input)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Convert to PyTorch Built-In" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Here we trigger the conversion via the Quantizer instance. Later on we show another way which does not\n", + "# require the quantizer\n", + "pyt_model = quantizer.convert_to_pytorch(dummy_input)\n", + "\n", + "# Note that the converted model is automatically moved to the CPU, regardless\n", + "# of the device of the Distiller model\n", + "print('Distiller model device:', distiller.model_device(quantizer.model))\n", + "print('PyTorch model device:', distiller.model_device(pyt_model))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Run Evaluation\n", + "### Distiller Model on GPU (if available)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if torch.cuda.is_available():\n", + " %time eval_model(test_loader_gpu, quantizer.model, 'cuda')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Distiller Model on CPU" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if torch.cuda.is_available():\n", + " print('Creating CPU copy of Distiller model')\n", + " cpu_model = distiller.make_non_parallel_copy(quantizer.model).cpu()\n", + "else:\n", + " cpu_model = quantizer.model\n", + "%time eval_model(test_loader_cpu, cpu_model, 'cpu', print_freq=60)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### PyTorch model in CPU\n", + "\n", + "We expect the PyTorch model on CPU to be much faster than the Distiller model on CPU" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%time eval_model(test_loader_cpu, pyt_model, 'cpu', print_freq=60)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## For the Extra-Curious: Comparing the Models" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "1. Distiller takes care of quantizing the inputs within the quantized modules PyTorch quantized modules assume the input is already quantized. Hence, for cases where a module's input is not quantized, we explicitly add a quantization operation for the input. The first layer in the model, `conv1` in ResNet18, is such a case\n", + "2. Both Distiller and native PyTorch support fused ReLU. In Distiller, this is somewhat obscurely indicated by the `clip_half_range` attribute inside `output_quant_settings`. In PyTorch, the module type is explicitly `QuantizedConvReLU2d`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print('conv1\\n')\n", + "print('DISTILLER:\\n{}\\n'.format(quantizer.model.module.conv1))\n", + "print('PyTorch:\\n{}\\n'.format(pyt_model.conv1))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Example of internal layers which don't require explicit input quantization:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print('layer1.0.conv1')\n", + "print(pyt_model.layer1[0].conv1)\n", + "print('\\nlayer1.0.add')\n", + "print(pyt_model.layer1[0].add)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Automatic de-quantization <--> quantization in the model\n", + "\n", + "For each quantized module in the Distiller implementation, we quantize the input and de-quantize the output.\n", + "So, if the user explicitly sets \"internal\" modules to run in FP32, this is transparent to the other quantized modules (at the cost of redundant quant-dequant operations).\n", + "\n", + "When converting to PyTorch we remove these redundant operations, and keep just the required ones in case the user explicitly decided to run some modules in FP32.\n", + "\n", + "For an example, consider a ResNet \"basic block\" with a residual connection that contains a downsampling convolution. Let's see how such a block looks in our fully-quantized, converted model:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(pyt_model.layer2[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can see all layers are either built-in quantized PyTorch modules, or identity operations representing fused operations. The entire block is quantized, so we don't see any quant-dequnt operations in the middle.\n", + "\n", + "Now let's create a new quantized model, and this time leave the 'downsample' module in FP32:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "overrides = OrderedDict(\n", + " [('layer2.0.downsample.0', OrderedDict([('bits_activations', None), ('bits_weights', None)]))]\n", + ")\n", + "new_quantizer = quant.PostTrainLinearQuantizer(\n", + " deepcopy(model), bits_activations=8, bits_parameters=8, mode=quant_mode,\n", + " model_activation_stats=stats_file, overrides=overrides\n", + ")\n", + "new_quantizer.prepare_model(dummy_input)\n", + "\n", + "new_pyt_model = new_quantizer.convert_to_pytorch(dummy_input)\n", + "\n", + "print(new_pyt_model.layer2[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can see a few differences:\n", + "1. The `downsample` module now contains a de-quantize op before the actual convolution\n", + "2. The `add` module now contains a quantize op before the actual add. Note that the add operation accepts 2 inputs. In this case the first input (index 0) comes from the `conv2` module, which is quantized. The second input (index 1) comes from the `downsample` module, which we kept in FP32. So, we only need to quantized the input at index 1. 
We can see this is indeed what is happening, by looking at the `ModuleDict` inside the `quant` module, and noticing it has only a single key for index \"1\".\n", + "\n", + "Let's see how the `add` module would look if we also kept the `conv2` module in FP32:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "overrides = OrderedDict(\n", + " [('layer2.0.downsample.0', OrderedDict([('bits_activations', None), ('bits_weights', None)])),\n", + " ('layer2.0.conv2', OrderedDict([('bits_activations', None), ('bits_weights', None)]))]\n", + ")\n", + "new_quantizer = quant.PostTrainLinearQuantizer(\n", + " deepcopy(model), bits_activations=8, bits_parameters=8, mode=quant_mode,\n", + " model_activation_stats=stats_file, overrides=overrides\n", + ")\n", + "new_quantizer.prepare_model(dummy_input)\n", + "\n", + "new_pyt_model = new_quantizer.convert_to_pytorch(dummy_input)\n", + "\n", + "print(new_pyt_model.layer2[0].add)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can see that now both inputs to the add module are being quantized." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Another API for Conversion\n", + "\n", + "In some cases we don't have the actual quantizer. For example - if the Distiller quantized module was loaded from a checkpoint. In those cases we can call a `distiller.quantization` module-level function (In fact, the Quantizer method we used earlier is a wrapper around this function).\n", + "\n", + "### Save Distiller model to checkpoint" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Save Distiller model to checkpoint and load it\n", + "distiller.apputils.save_checkpoint(0, 'resnet18', quantizer.model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load Checkpoint\n", + "\n", + "The model is quantized when the checkpoint is loaded" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "loaded_model = create_model(False, dataset='imagenet', arch='resnet18', parallel=True)\n", + "loaded_model = distiller.apputils.load_lean_checkpoint(loaded_model, 'checkpoint.pth.tar')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Convert and Evaluate" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Convert\n", + "loaded_pyt_model = distiller.quantization.convert_distiller_ptq_model_to_pytorch(loaded_model, dummy_input)\n", + "\n", + "# Run evaluation\n", + "%time eval_model(test_loader_cpu, loaded_pyt_model, 'cpu', print_freq=60)\n", + "\n", + "# Cleanup\n", + "os.remove('checkpoint.pth.tar')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.5.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/test_ptq_pytorch_convert.py b/tests/test_ptq_pytorch_convert.py new file mode 100644 index 0000000..81919cd --- /dev/null +++ b/tests/test_ptq_pytorch_convert.py @@ -0,0 +1,126 @@ +# +# Copyright (c) 2020 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the 
"License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest + +import torch +import torch.testing + +import distiller.quantization as quantization +from distiller.quantization.range_linear import _get_quant_params_from_tensor + + +@pytest.fixture(params=[quantization.LinearQuantMode.SYMMETRIC, + quantization.LinearQuantMode.SYMMETRIC_RESTRICTED, + quantization.LinearQuantMode.ASYMMETRIC_UNSIGNED, + quantization.LinearQuantMode.ASYMMETRIC_SIGNED], + ids=['symmetric', 'symmetric_restricted', 'asym_unsigned', 'asym_signed']) +def distiller_mode(request): + return request.param + + +@pytest.fixture(params=[torch.qint8, torch.quint8], ids=['torch.qint8', 'torch.quint8']) +def torch_dtype(request): + return request.param + + +@pytest.fixture(params=[False, True], ids=['per_tensor', 'per_channel']) +def per_channel(request): + return request.param + + +@pytest.fixture(params=[False, True], ids=['reduce_off', 'reduce_on']) +def reduce_range(request): + return request.param + + +@pytest.fixture() +def num_bits(): + return 8 + + +@pytest.fixture() +def tensor(): + return torch.randn(64, 256, 7, 7) + + +def test_qparams_conversion(tensor, num_bits, distiller_mode, torch_dtype, per_channel, reduce_range): + if reduce_range: + if num_bits != 8: + return True + if quantization.is_linear_quant_mode_symmetric(distiller_mode) and torch_dtype == torch.quint8: + return True + + # Calculate quantization parameters with Distiller for number of bits BEFORE reduce_range + signed = distiller_mode != quantization.LinearQuantMode.ASYMMETRIC_UNSIGNED + distiller_scale, distiller_zp = _get_quant_params_from_tensor(tensor, num_bits, distiller_mode, + per_channel=per_channel) + + # Convert parameters to PyTorch + converted_scale, converted_zp = quantization.distiller_qparams_to_pytorch( + distiller_scale, distiller_zp, num_bits, distiller_mode, torch_dtype, reduce_range + ) + + # Quantize tensor with Distiller + # If reduce_range is set, then we actually quantize with num_bits-1 + if reduce_range: + num_bits -= 1 + distiller_scale, distiller_zp = _get_quant_params_from_tensor(tensor, num_bits, distiller_mode, + per_channel=per_channel) + restrict = distiller_mode == quantization.LinearQuantMode.SYMMETRIC_RESTRICTED + clamp_min, clamp_max = quantization.get_quantized_range(num_bits, signed=signed, signed_restrict_qrange=restrict) + distiller_q_t = quantization.linear_quantize_clamp(tensor, distiller_scale, distiller_zp, clamp_min, clamp_max) + + # Quantize with PyTorch + if per_channel: + pytorch_q_t = torch.quantize_per_channel(tensor, converted_scale, converted_zp, 0, torch_dtype) + else: + pytorch_q_t = torch.quantize_per_tensor(tensor, converted_scale, converted_zp, torch_dtype) + + # Dequantize + distiller_q_dq_t = quantization.linear_dequantize(distiller_q_t, distiller_scale, distiller_zp) + pytorch_q_dq_t = pytorch_q_t.dequantize() + + # Compare - allow of up to one quantized "bin" between the tensors + if per_channel: + for idx, scale in enumerate(converted_scale): + torch.testing.assert_allclose(distiller_q_dq_t[idx], pytorch_q_dq_t[idx], atol=scale, rtol=1e-05) + 
else: + torch.testing.assert_allclose(pytorch_q_dq_t, distiller_q_dq_t, atol=converted_scale, rtol=1e-05) + + +def test_quantized_tensor_conversion(tensor, num_bits, distiller_mode, torch_dtype, per_channel): + # Quantize tensor with Distiller + signed = distiller_mode != quantization.LinearQuantMode.ASYMMETRIC_UNSIGNED + distiller_scale, distiller_zp = _get_quant_params_from_tensor(tensor, num_bits, distiller_mode, + per_channel=per_channel) + restrict = distiller_mode == quantization.LinearQuantMode.SYMMETRIC_RESTRICTED + clamp_min, clamp_max = quantization.get_quantized_range(num_bits, signed=signed, signed_restrict_qrange=restrict) + distiller_q_t = quantization.linear_quantize_clamp(tensor, distiller_scale, distiller_zp, clamp_min, clamp_max) + + # Convert tensor to PyTorch + pytorch_q_t = quantization.distiller_quantized_tensor_to_pytorch( + distiller_q_t, distiller_scale, distiller_zp, num_bits, distiller_mode, torch_dtype, per_channel, 0 + ) + + # Dequantize both + distiller_q_dq_t = quantization.linear_dequantize(distiller_q_t, distiller_scale, distiller_zp) + pytorch_q_dq_t = pytorch_q_t.dequantize() + + # Compare + torch.testing.assert_allclose(pytorch_q_dq_t, distiller_q_dq_t) + + +#TODO: Add tests of full model conversion -- GitLab
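For reviewers who want to exercise the new APIs end to end, the sketch below strings together the pieces this patch adds (PostTrainLinearQuantizer.convert_to_pytorch, convert_distiller_ptq_model_to_pytorch, and checkpoint save/load). It is illustrative only: the ResNet-18/ImageNet setup, the stats-file path and the checkpoint file name are placeholders borrowed from the patch's notebook, not requirements.

import distiller
import distiller.quantization as quant
from distiller.models import create_model
from distiller.apputils import save_checkpoint, load_lean_checkpoint

# Build an FP32 model and post-train quantize it with Distiller
model = create_model(pretrained=True, dataset='imagenet', arch='resnet18', parallel=False)
dummy_input = distiller.get_dummy_input(input_shape=model.input_shape)
quantizer = quant.PostTrainLinearQuantizer(
    model, bits_activations=8, bits_parameters=8,
    mode={'activations': 'ASYMMETRIC_UNSIGNED', 'weights': 'SYMMETRIC'},
    model_activation_stats='examples/quantization/post_train_quant/stats/resnet18_quant_stats.yaml')
quantizer.prepare_model(dummy_input)

# Convert to native PyTorch quantization modules (the converted model lives on the CPU)
pyt_model = quantizer.convert_to_pytorch(dummy_input, backend='fbgemm')

# The same conversion, starting from a saved checkpoint of the Distiller-quantized model,
# i.e. without access to the quantizer instance
save_checkpoint(0, 'resnet18', quantizer.model)
loaded = create_model(pretrained=False, dataset='imagenet', arch='resnet18', parallel=False)
loaded = load_lean_checkpoint(loaded, 'checkpoint.pth.tar')
pyt_model_2 = quant.convert_distiller_ptq_model_to_pytorch(loaded, dummy_input)

From the image classification sample, the same flows are triggered on the command line with --evaluate --quantize-eval --qe-convert-pytorch, or with --evaluate --resume <checkpoint> --qe-convert-pytorch when loading an already post-train-quantized checkpoint; --qe-pytorch-backend selects between fbgemm and qnnpack.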