From cdc1775f64cf081dc426871978a9ff1250206cda Mon Sep 17 00:00:00 2001
From: Guy Jacob <guy.jacob@intel.com>
Date: Thu, 6 Feb 2020 16:55:35 +0200
Subject: [PATCH] Convert Distiller PTQ models to "native" PyTorch PTQ (#458)

* New API: distiller.quantization.convert_distiller_ptq_model_to_pytorch()
  (see the usage sketch below)
* Can also be called from a PostTrainLinearQuantizer instance:
    quantizer.convert_to_pytorch()
* Can also be triggered from the command line in the image classification sample
* Can save/load converted modules via apputils.load/save_checkpoint
* Added Jupyter notebook tutorial

* Converted modules have only the absolutely necessary quant-dequant
  operations. For a fully quantized model, this means just quantization
  of model input and de-quantization of model output. If a user keeps
  specific internal layers in FP32, quant-dequant operations are added
  as needed
* Can configure either 'fbgemm' or 'qnnpack' backend. For 'fbgemm' we
  take care of preventing overflows (aka "reduce_range" in the PyTorch
  API)
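
A rough usage sketch (the stats file name below is a placeholder, not part of this
patch; see the included Jupyter notebook for a full walk-through):

    import distiller
    from distiller.models import create_model
    from distiller.quantization import PostTrainLinearQuantizer

    model = create_model(pretrained=True, dataset='imagenet', arch='resnet18', parallel=False)
    dummy_input = distiller.get_dummy_input(input_shape=model.input_shape)

    quantizer = PostTrainLinearQuantizer(model, bits_activations=8, bits_parameters=8,
                                         model_activation_stats='acts_quant_stats.yaml')
    quantizer.prepare_model(dummy_input)

    # Convert to a model built solely from native PyTorch quantized modules (runs on the CPU)
    pyt_model = quantizer.convert_to_pytorch(dummy_input, backend='fbgemm')

    # The module-level function can be used instead, e.g. when a quantizer instance
    # is not available:
    # pyt_model = distiller.quantization.convert_distiller_ptq_model_to_pytorch(
    #     model, dummy_input, backend='fbgemm')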
---
 distiller/apputils/checkpoint.py              |   5 +
 distiller/apputils/image_classifier.py        |  25 +-
 distiller/models/__init__.py                  |   3 +-
 distiller/quantization/__init__.py            |   2 +
 .../quantization/pytorch_quant_conversion.py  | 436 ++++++++++++++++
 distiller/quantization/quantizer.py           |   8 +-
 distiller/quantization/range_linear.py        | 170 +++++++
 .../post_train_quant/command_line.md          |  14 +-
 .../post_train_quant_convert_pytorch.ipynb    | 481 ++++++++++++++++++
 tests/test_ptq_pytorch_convert.py             | 126 +++++
 10 files changed, 1264 insertions(+), 6 deletions(-)
 create mode 100644 distiller/quantization/pytorch_quant_conversion.py
 create mode 100644 jupyter/post_train_quant_convert_pytorch.ipynb
 create mode 100644 tests/test_ptq_pytorch_convert.py

diff --git a/distiller/apputils/checkpoint.py b/distiller/apputils/checkpoint.py
index 95cda26..70e9de0 100755
--- a/distiller/apputils/checkpoint.py
+++ b/distiller/apputils/checkpoint.py
@@ -29,6 +29,7 @@ from tabulate import tabulate
 import torch
 import distiller
 from distiller.utils import normalize_module_name
+import distiller.quantization as quantization
 msglogger = logging.getLogger()
 
 
@@ -224,6 +225,10 @@ def load_checkpoint(model, chkpt_file, optimizer=None,
         quantizer = qmd['type'](model, **qmd['params'])
         quantizer.prepare_model(qmd['dummy_input'])
 
+        if qmd.get('pytorch_convert', False):
+            msglogger.info('Converting Distiller PTQ model to PyTorch quantization API')
+            model = quantization.convert_distiller_ptq_model_to_pytorch(model, dummy_input=qmd['dummy_input'])
+
     if normalize_dataparallel_keys:
         checkpoint['state_dict'] = {normalize_module_name(k): v for k, v in checkpoint['state_dict'].items()}
     anomalous_keys = model.load_state_dict(checkpoint['state_dict'], strict)
diff --git a/distiller/apputils/image_classifier.py b/distiller/apputils/image_classifier.py
index 3dd6339..845a96f 100755
--- a/distiller/apputils/image_classifier.py
+++ b/distiller/apputils/image_classifier.py
@@ -405,7 +405,7 @@ def _init_learner(args):
             optimizer = None
             msglogger.info('\nreset_optimizer flag set: Overriding resumed optimizer and resetting epoch count to 0')
 
-    if optimizer is None:
+    if optimizer is None and not args.evaluate:
         optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                                     momentum=args.momentum, weight_decay=args.weight_decay)
         msglogger.debug('Optimizer Type: %s', type(optimizer))
@@ -822,6 +822,15 @@ def earlyexit_validate_stats(args):
     return total_top1, total_top5, losses_exits_stats
 
 
+def _convert_ptq_to_pytorch(model, args):
+    msglogger.info('Converting Distiller PTQ model to PyTorch quantization API')
+    dummy_input = distiller.get_dummy_input(input_shape=model.input_shape)
+    model = quantization.convert_distiller_ptq_model_to_pytorch(model, dummy_input, backend=args.qe_pytorch_backend)
+    msglogger.debug('\nModel after conversion:\n{}'.format(model))
+    args.device = 'cpu'
+    return model
+
+
 def evaluate_model(test_loader, model, criterion, loggers, activations_collectors=None, args=None, scheduler=None):
     # This sample application can be invoked to evaluate the accuracy of your model on
     # the test dataset.
@@ -833,6 +842,9 @@ def evaluate_model(test_loader, model, criterion, loggers, activations_collector
         loggers = [loggers]
 
     if not args.quantize_eval:
+        # Handle case where a post-train quantized model was loaded, and user wants to convert it to PyTorch
+        if args.qe_convert_pytorch:
+            model = _convert_ptq_to_pytorch(model, args)
         return test(test_loader, model, criterion, loggers, activations_collectors, args=args)
     else:
         return quantize_and_test_model(test_loader, model, criterion, args, loggers,
@@ -849,6 +861,11 @@ def quantize_and_test_model(test_loader, model, criterion, args, loggers=None, s
     scheduler - pass scheduler to store it in checkpoint
     save_flag - defaults to save both quantization statistics and checkpoint.
     """
+    if hasattr(model, 'quantizer_metadata') and \
+            model.quantizer_metadata['type'] == distiller.quantization.PostTrainLinearQuantizer:
+        raise RuntimeError('Trying to invoke post-training quantization on a model that has already been post-'
+                           'train quantized. Model was likely loaded from a checkpoint. Please run again without '
+                           'passing the --quantize-eval flag')
     if not (args.qe_dynamic or args.qe_stats_file or args.qe_config_file):
         args_copy = copy.deepcopy(args)
         args_copy.qe_calibration = args.qe_calibration if args.qe_calibration is not None else 0.05
@@ -865,7 +882,11 @@ def quantize_and_test_model(test_loader, model, criterion, args, loggers=None, s
         qe_model = copy.deepcopy(model).to(args.device)
 
     quantizer = quantization.PostTrainLinearQuantizer.from_args(qe_model, args_qe)
-    quantizer.prepare_model(distiller.get_dummy_input(input_shape=model.input_shape))
+    dummy_input = distiller.get_dummy_input(input_shape=model.input_shape)
+    quantizer.prepare_model(dummy_input)
+
+    if args.qe_convert_pytorch:
+        qe_model = _convert_ptq_to_pytorch(qe_model, args_qe)
 
     test_res = test(test_loader, qe_model, criterion, loggers, args=args_qe)
 
diff --git a/distiller/models/__init__.py b/distiller/models/__init__.py
index 6ac57c8..c0569fc 100755
--- a/distiller/models/__init__.py
+++ b/distiller/models/__init__.py
@@ -146,14 +146,15 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
                 model.features = torch.nn.DataParallel(model.features, device_ids=device_ids)
             else:
                 model = torch.nn.DataParallel(model, device_ids=device_ids)
+        model.is_parallel = parallel
     else:
         device = 'cpu'
+        model.is_parallel = False
 
     # Cache some attributes which describe the model
     _set_model_input_shape_attr(model, arch, dataset, pretrained, cadene)
     model.arch = arch
     model.dataset = dataset
-    model.is_parallel = parallel
     return model.to(device)
 
 
diff --git a/distiller/quantization/__init__.py b/distiller/quantization/__init__.py
index 553877f..3eaeb80 100644
--- a/distiller/quantization/__init__.py
+++ b/distiller/quantization/__init__.py
@@ -21,6 +21,8 @@ from .range_linear import RangeLinearQuantWrapper, RangeLinearQuantParamLayerWra
     RangeLinearEmbeddingWrapper, RangeLinearFakeQuantWrapper, RangeLinearQuantMatmulWrapper
 from .clipped_linear import LinearQuantizeSTE, ClippedLinearQuantization, WRPNQuantizer, DorefaQuantizer, PACTQuantizer
 from .q_utils import *
+from .pytorch_quant_conversion import convert_distiller_ptq_model_to_pytorch, distiller_qparams_to_pytorch, \
+    distiller_quantized_tensor_to_pytorch
 
 del quantizer
 del range_linear
diff --git a/distiller/quantization/pytorch_quant_conversion.py b/distiller/quantization/pytorch_quant_conversion.py
new file mode 100644
index 0000000..98501b4
--- /dev/null
+++ b/distiller/quantization/pytorch_quant_conversion.py
@@ -0,0 +1,436 @@
+#
+# Copyright (c) 2020 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import torch
+import torch.nn as nn
+import torch.nn.quantized as nnq
+from collections import OrderedDict
+import warnings
+from copy import deepcopy
+
+import distiller
+from .q_utils import LinearQuantMode, is_linear_quant_mode_symmetric
+
+
+def need_reduce_range(distiller_quant_mode, torch_dtype):
+    return torch.backends.quantized.engine == 'fbgemm' and not(is_linear_quant_mode_symmetric(distiller_quant_mode) and
+                                                               torch_dtype == torch.quint8)
+
+
+def distiller_qparams_to_pytorch(scale, zp, num_bits, distiller_mode, dest_dtype, reduce_range=False):
+    """
+    Convert quantization parameters (scale and zero-point) calculated by Distiller APIs to quantization parameters
+    compatible with PyTorch quantization APIs.
+
+    By "calculated with Distiller APIs" we mean calculated using either of:
+      * distiller.quantization.symmetric_linear_quantization_params
+      * distiller.quantization.asymmetric_linear_quantization_params
+
+    The main differences between quantization parameters as calculated by Distiller and PyTorch:
+      * pytorch_scale = 1 / distiller_scale
+      * pytorch_zero_point = -distiller_zero_point
+
+    Args:
+        scale (torch.Tensor): Scale factor calculated by Distiller
+        zp (torch.Tensor): Zero point calculated by Distiller
+        num_bits (int): Number of bits used for quantization in Distiller
+        distiller_mode (distiller.quantization.LinearQuantMode): The quantization mode used in Distiller
+        dest_dtype (torch.dtype): PyTorch quantized dtype to convert to. Must be one of: torch.quint8, torch.qint8
+        reduce_range (bool): Reduces the range of the quantized data type by 1 bit. This should mainly be used for
+          quantized activations with the "fbgemm" PyTorch backend - it prevents overflows. See:
+          https://github.com/pytorch/pytorch/blob/fde94e75568b527b424b108c272793e096e8e471/torch/quantization/observer.py#L294
+
+    Returns:
+        Tuple of (scale, zero_point) which are compatible with PyTorch quantization API
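+
+    Example (a sketch with illustrative values; assume Distiller calculated scale=63.75 and
+    zero-point=-64 for an 8-bit ASYMMETRIC_UNSIGNED tensor):
+        scale, zp = distiller_qparams_to_pytorch(torch.tensor(63.75), torch.tensor(-64.), 8,
+                                                 LinearQuantMode.ASYMMETRIC_UNSIGNED,
+                                                 torch.quint8, reduce_range=False)
+        # scale == 1 / 63.75 ~= 0.0157 ; zp == 64 (the Distiller zero-point, negated)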
+    """
+    assert dest_dtype in (torch.qint8, torch.quint8), 'Must specify one of the quantized PyTorch dtypes'
+
+    distiller_symmetric = is_linear_quant_mode_symmetric(distiller_mode)
+    if distiller_symmetric and dest_dtype == torch.quint8:
+        reduce_range = False
+
+    distiller_asym_signed = distiller_mode == LinearQuantMode.ASYMMETRIC_SIGNED
+
+    if reduce_range:
+        assert num_bits == 8, 'reduce_range needed only when num_bits == 8'
+        if distiller_symmetric and dest_dtype == torch.quint8:
+            raise NotImplementedError('reduce_range + symmetric + quint8 not supported in PyTorch')
+        num_bits = 7
+        if distiller_symmetric:
+            ratio = 63. / 127.
+        else:
+            ratio = 127. / 255.
+            zp_offset = 128 if distiller_asym_signed else 0
+            zp = ((zp - zp_offset) * ratio + zp_offset / 2).round()
+        scale = scale * ratio
+
+    scale = scale.cpu().squeeze()
+    zp = zp.cpu().squeeze().long()
+
+    # Distiller scale is the reciprocal of PyTorch scale
+    scale_torch = 1. / scale
+
+    n_bins_half = 2 ** (num_bits - 1)
+
+    if distiller_symmetric:
+        # In Distiller symmetric is always signed with zero-point = 0, but in PyTorch it can be
+        # unsigned in which case we offset the zero-point to the middle of the quantized range
+        zp_torch = zp if dest_dtype == torch.qint8 else torch.full_like(zp, n_bins_half)
+    else:
+        pytorch_signed = dest_dtype == torch.qint8
+        if distiller_asym_signed and not pytorch_signed:
+            zp = zp - n_bins_half
+        elif not distiller_asym_signed and pytorch_signed:
+            zp = zp + n_bins_half
+        # Distiller subtracts the zero-point when quantizing, PyTorch adds it.
+        # So we negate the zero-point calculated in Distiller
+        zp_torch = -zp
+    return scale_torch, zp_torch
+
+
+def distiller_quantized_tensor_to_pytorch(tensor: torch.Tensor, scale, zp, num_bits, distiller_mode, dest_dtype,
+                                          per_channel=False, channel_dim=0):
+    """
+    Convert a tensor quantized with quantization parameters calculated by Distiller to a PyTorch "native" quantized
+    tensor.
+
+    We refer to quantization parameters calculated using either of:
+      * distiller.quantization.symmetric_linear_quantization_params
+      * distiller.quantization.asymmetric_linear_quantization_params
+
+    And to tensors quantized using either of:
+      * distiller.quantization.linear_quantize
+      * distiller.quantization.linear_quantize_clamp
+
+    Args:
+        tensor (torch.Tensor): The tensor quantized in Distiller
+        scale (torch.Tensor): Scale factor calculated by Distiller
+        zp (torch.Tensor): Zero point calculated by Distiller
+        num_bits (int): Number of bits used for quantization in Distiller
+        distiller_mode (distiller.quantization.LinearQuantMode): The quantization mode used in Distiller
+        dest_dtype (torch.dtype): PyTorch quantized dtype to convert to. Must be one of: torch.quint8, torch.qint8
+        per_channel (bool): Flag indicating whether the tensor was quantized per-channel
+        channel_dim (int): If per_channel is set, this indicates the dimension of the channel in the tensor
+
+    Returns:
+        PyTorch quantized tensor (dtype one of torch.quint8 / torch.qint8 / torch.qint32)
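+
+    Example (a sketch; `w_int`, `scale` and `zp` are assumed to come from Distiller's
+    linear_quantize / symmetric_linear_quantization_params with num_bits=8):
+        q_weight = distiller_quantized_tensor_to_pytorch(w_int, scale, zp, 8,
+                                                         LinearQuantMode.SYMMETRIC, torch.qint8)
+        # q_weight is a torch.qint8 quantized tensor; q_weight.dequantize() yields the FP32 values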
+    """
+    assert (tensor == tensor.int()).all(), 'Tensor does not appear to be quantized'
+    converted_scale, converted_zp = distiller_qparams_to_pytorch(scale, zp, num_bits, distiller_mode, dest_dtype,
+                                                                 reduce_range=False)
+    zp_diff = -converted_zp.view(zp.shape) - zp
+
+    if dest_dtype == torch.quint8:
+        temp_dtype = torch.uint8
+    elif dest_dtype == torch.qint8:
+        temp_dtype = torch.int8
+    else:  # dest_dtype == torch.qint32:
+        temp_dtype = torch.int32
+    tensor = (tensor - zp_diff).to(temp_dtype)
+    if per_channel:
+        return torch._make_per_channel_quantized_tensor(tensor, converted_scale, converted_zp, channel_dim)
+    return torch._make_per_tensor_quantized_tensor(tensor, converted_scale, converted_zp)
+
+
+def _ptq_convert_pass_replace_range_linear_wrappers(module):
+    # Hacky deferred import for now to work around a circular dependency
+    # TODO: Proper fix
+    from distiller.quantization import RangeLinearQuantWrapper
+
+    reassign = OrderedDict()
+    for n, m in module.named_children():
+        new_m = m
+        if isinstance(m, distiller.quantization.RangeLinearQuantWrapper):
+            new_m = m.to_pytorch_quant(need_reduce_range(m.output_quant_settings.quant_mode, torch.quint8))
+
+            requires_quantized_inputs = not (isinstance(new_m, nn.Sequential) and
+                                             isinstance(new_m[0], ConditionalDeQuantizeWrapper))
+
+            if requires_quantized_inputs:
+                d = OrderedDict()
+                for idx, qmd in m.inputs_quant_metadata_fallback.items():
+                    qset = m.inputs_quant_settings_overrides.get(idx, m.output_quant_settings)
+                    scale, zp = distiller_qparams_to_pytorch(qmd.scale, qmd.zero_point, qset.num_bits,
+                                                             qset.quant_mode, torch.quint8,
+                                                             need_reduce_range(qset.quant_mode, torch.quint8))
+                    d[idx] = (scale, zp, torch.quint8)
+                new_m = ConditionalQuantizeWrapper(new_m, d)
+        elif distiller.has_children(m):
+            new_m = _ptq_convert_pass_replace_range_linear_wrappers(m)
+        elif not isinstance(m, nn.Identity):
+            # Module not quantized in Distiller, possibly need to de-quant input
+            new_m = ConditionalDeQuantizeWrapper(m)
+        reassign[n] = new_m
+
+    for n, new_m in reassign.items():
+        module._modules[n] = new_m
+
+    return module
+
+
+def patch_model_output_dequant(model):
+    def patched_forward(self, input):
+        out = self._original_forward(input)
+        out = self.output_dequant(out)
+        return out
+
+    model.add_module('output_dequant', nnq.DeQuantize())
+    model._original_forward = model.forward
+    # https://stackoverflow.com/questions/972/adding-a-method-to-an-existing-object-instance#comment66379065_2982
+    model.forward = patched_forward.__get__(model)
+
+
+def _ptq_convert_pass_remove_redundant_quant_dequant(model, dummy_input):
+    def quantize_wrapper_check_hook(module, inputs):
+        if not isinstance(module, ConditionalQuantize):
+            return
+        q_inputs = []
+        for idx, t in enumerate(inputs):
+            if not isinstance(t, torch.Tensor):
+                continue
+            if t.is_quantized:
+                q_inputs.append(idx)
+        module.already_quantized = q_inputs
+
+    def dequant_wrapper_check_hook(module, input):
+        if not isinstance(module, ConditionalDeQuantize):
+            return
+        module.any_quantized = False
+
+        def check_recursively(x):
+            if isinstance(x, torch.Tensor) and x.is_quantized:
+                module.any_quantized = True
+            elif isinstance(x, (tuple, list)):
+                for item in x:
+                    check_recursively(item)
+
+        check_recursively(input)
+
+    def cleanup(module):
+        reassign = OrderedDict()
+        for n, m in module.named_children():
+            new_m = m
+            if isinstance(m, ConditionalQuantizeWrapper):
+                for idx in m.quant.already_quantized:
+                    if str(idx) in m.quant.quantizers:
+                        m.quant.quantizers.pop(str(idx))
+                if len(m.quant.quantizers) == 0:
+                    new_m = m.wrapped
+            elif isinstance(m, ConditionalDeQuantizeWrapper):
+                if not m.dequant.any_quantized:
+                    new_m = m.wrapped
+            elif distiller.has_children(m):
+                cleanup(m)
+            reassign[n] = new_m
+        for n, new_m in reassign.items():
+            module._modules[n] = new_m
+
+        return module
+
+    handles = []
+    for m in model.modules():
+        if isinstance(m, ConditionalQuantize):
+            handles.append(m.register_forward_pre_hook(quantize_wrapper_check_hook))
+        elif isinstance(m, ConditionalDeQuantize):
+            handles.append(m.register_forward_pre_hook(dequant_wrapper_check_hook))
+    out = model(dummy_input)
+    for h in handles:
+        h.remove()
+
+    model = cleanup(model)
+
+    if out.is_quantized:
+        patch_model_output_dequant(model)
+
+    return model
+
+
+def convert_distiller_ptq_model_to_pytorch(model, dummy_input, backend='fbgemm'):
+    """
+    Convert a model quantized using distiller.quantization.PostTrainLinearQuantizer to a model comprised solely of
+    native PyTorch static post-training quantization modules and operators.
+
+    In the current implementation this conversion CANNOT be done in-place.
+
+    Conversion is done in 2 passes:
+      * First pass: Replace all RangeLinearQuantWrapper modules with a quantize operation followed by the respective
+        native PyTorch module. Modules that weren't quantized by Distiller are wrapped with a de-quantize operation.
+      * Second pass: Perform dummy forward pass over the model and remove redundant de-quant --> quant sequences.
+
+    The converted model returns a de-quantized output. If the last layer of the model is quantized, then an extra
+    dequantize module will be added to the model. This extra module is named 'output_dequant', and the model's
+    forward method is patched to execute this module after the main model.
+    NOTE: This assumes the model produces a single output tensor. In other cases the behavior is undefined.
+
+    NOTE: The converted model will be on the CPU, and non-parallel (that is - without nn.DataParallel modules)
+
+    Args:
+        model (torch.nn.Module): The model to be converted
+        dummy_input (torch.Tensor): A tensor in the shape expected by the model, required for the second pass
+          of the conversion
+        backend (str): The PyTorch quantization backend to use. Currently supported values: 'fbgemm', 'qnnpack'
+
+    Returns:
+        The converted model
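+
+    Example (a sketch; assumes `quantizer` is a PostTrainLinearQuantizer on which prepare_model was
+    already called with the same `dummy_input`):
+        pyt_model = convert_distiller_ptq_model_to_pytorch(quantizer.model, dummy_input)
+        out = pyt_model(dummy_input.cpu())  # the converted model runs on the CPU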
+    """
+    # Hacky deferred import for now to work around a circular dependency
+    # TODO: Proper fix
+    from distiller.quantization import PostTrainLinearQuantizer
+    if not hasattr(model, 'quantizer_metadata') or model.quantizer_metadata['type'] != PostTrainLinearQuantizer:
+        raise ValueError('Conversion to PyTorch native quantization supported only for models quantized '
+                         'using distiller.quantization.PostTrainLinearQuantizer')
+
+    if dummy_input is None or not isinstance(dummy_input, torch.Tensor):
+        raise ValueError('Valid dummy input tensor required for converting PTQ model to PyTorch')
+
+    backends = ('fbgemm', 'qnnpack')
+    if backend not in backends:
+        raise ValueError('{} is not a supported PyTorch quantization backend. Supported: {}'.format(backend, backends))
+    torch.backends.quantized.engine = backend
+
+    # TODO: Add in-place option. Not totally straight-forward because of the output dequantization
+    #       Can monkey-patch instead of creating a Sequential, then it can really be in-place
+
+    # Save quantizer metadata so we can re-attach it to the model after conversion, which enables loading the
+    # converted model from a checkpoint
+    quantizer_metadata = deepcopy(model.quantizer_metadata)
+    model = distiller.make_non_parallel_copy(model).cpu()
+
+    # First pass
+    model = _ptq_convert_pass_replace_range_linear_wrappers(model)
+
+    # Second pass
+    model = _ptq_convert_pass_remove_redundant_quant_dequant(model, dummy_input)
+
+    # This is used when loading the model from a checkpoint, to indicate that conversion needs to be applied
+    quantizer_metadata['pytorch_convert'] = True
+    model.quantizer_metadata = quantizer_metadata
+
+    return model
+
+
+class QFunctionalWrapper(nn.Module):
+    def __init__(self):
+        super(QFunctionalWrapper, self).__init__()
+        self.qfunc = nnq.QFunctional()
+
+
+class QFunctionalAdd(QFunctionalWrapper):
+    def __init__(self):
+        super(QFunctionalAdd, self).__init__()
+
+    def forward(self, x, y):
+        return self.qfunc.add(x, y)
+
+
+class QFunctionalAddScalar(QFunctionalWrapper):
+    def __init__(self):
+        super(QFunctionalAddScalar, self).__init__()
+
+    def forward(self, x, y):
+        return self.qfunc.add_scalar(x, y)
+
+
+class QFunctionalMul(QFunctionalWrapper):
+    def __init__(self):
+        super(QFunctionalMul, self).__init__()
+
+    def forward(self, x, y):
+        return self.qfunc.mul(x, y)
+
+
+class QFunctionalMulScalar(QFunctionalWrapper):
+    def __init__(self):
+        super(QFunctionalMulScalar, self).__init__()
+
+    def forward(self, x, y):
+        return self.qfunc.mul_scalar(x, y)
+
+
+class QFunctionalCat(QFunctionalWrapper):
+    def __init__(self, dim=0):
+        super(QFunctionalCat, self).__init__()
+        self.dim = dim
+
+    def forward(self, *x):
+        return self.qfunc.cat(x, self.dim)
+
+
+class QFunctionalAddRelu(QFunctionalWrapper):
+    def __init__(self):
+        super(QFunctionalAddRelu, self).__init__()
+
+    def forward(self, x, y):
+        return self.qfunc.add_relu(x, y)
+
+
+class ConditionalDeQuantize(nn.Module):
+    def __init__(self):
+        super(ConditionalDeQuantize, self).__init__()
+
+    def forward(self, *inputs):
+        def dequant_recursively(x):
+            if isinstance(x, torch.Tensor):
+                return x.dequantize() if x.is_quantized else x
+            if isinstance(x, (tuple, list)):
+                return type(x)(dequant_recursively(item) for item in x)
+            return x
+        outputs = dequant_recursively(inputs)
+        return outputs
+
+
+class ConditionalDeQuantizeWrapper(nn.Module):
+    def __init__(self, wrapped_module):
+        super(ConditionalDeQuantizeWrapper, self).__init__()
+        self.dequant = ConditionalDeQuantize()
+        self.wrapped = wrapped_module
+
+    def forward(self, *inputs):
+        out = self.dequant(*inputs)
+        out = self.wrapped(*out)
+        return out
+
+
+class ConditionalQuantize(nn.Module):
+    def __init__(self, inputs_to_qparams_map):
+        super(ConditionalQuantize, self).__init__()
+        self.quantizers = nn.ModuleDict()
+        for idx, qparams in inputs_to_qparams_map.items():
+            self.quantizers[str(idx)] = nnq.Quantize(*qparams)
+
+    def forward(self, *inputs):
+        q_inputs = []
+        for idx, item in enumerate(inputs):
+            idx_str = str(idx)
+            if idx_str in self.quantizers:
+                assert isinstance(item, torch.Tensor), 'Trying to quantize a non-Tensor object'
+                if not item.is_quantized:
+                    item = self.quantizers[idx_str](item)
+            q_inputs.append(item)
+        # return q_inputs[0] if len(q_inputs) == 1 else tuple(q_inputs)
+        return tuple(q_inputs)
+
+
+class ConditionalQuantizeWrapper(nn.Module):
+    def __init__(self, wrapped_module, inputs_to_qparams_map):
+        super(ConditionalQuantizeWrapper, self).__init__()
+        self.quant = ConditionalQuantize(inputs_to_qparams_map)
+        self.wrapped = wrapped_module
+
+    def forward(self, *inputs):
+        out = self.quant(*inputs)
+        out = self.wrapped(*out)
+        return out
diff --git a/distiller/quantization/quantizer.py b/distiller/quantization/quantizer.py
index 1b33530..f4bc448 100644
--- a/distiller/quantization/quantizer.py
+++ b/distiller/quantization/quantizer.py
@@ -190,6 +190,8 @@ class Quantizer(object):
         self.modules_processed = OrderedDict()
         self.modules_processed_args = OrderedDict()
 
+        self.prepared = False
+
     def _add_qbits_entry(self, module_name, module_type, qbits):
         if module_type not in [nn.Conv2d, nn.Conv3d, nn.Linear, nn.Embedding]:
             # For now we support weights quantization only for Conv, FC and Embedding layers (so, for example, we don't
@@ -247,7 +249,7 @@ class Quantizer(object):
                 self.params_to_quantize.append(_ParamToQuant(module, module_name, fp_attr_name, param_name, n_bits))
 
                 param_full_name = '.'.join([module_name, param_name])
-                msglogger.info(
+                msglogger.debug(
                     "Parameter '{0}' will be quantized to {1} bits".format(param_full_name, n_bits))
 
         # If an optimizer was passed, assume we need to update it
@@ -262,7 +264,9 @@ class Quantizer(object):
 
         distiller.assign_layer_fq_names(self.model)
 
-        msglogger.info('Quantized model:\n\n{0}\n'.format(self.model))
+        self.prepared = True
+
+        msglogger.debug('Quantized model:\n\n{0}\n'.format(self.model))
 
     def _pre_prepare_model(self, dummy_input):
         pass
diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index 5e3c7f8..ef92314 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -33,6 +33,11 @@ from .q_utils import *
 from .sim_bn_fold import SimulatedFoldedBatchNorm
 import distiller.modules
 import distiller.model_transforms as mt
+from . import pytorch_quant_conversion as pytqc
+
+import torch.quantization
+import torch.nn.quantized as nnq
+import torch.nn.intrinsic.quantized as nniq
 
 msglogger = logging.getLogger()
 
@@ -326,6 +331,10 @@ def add_post_train_quant_args(argparser):
                             'this number of bits the integer multiplier')
     group.add_argument('--qe-save-fp-weights', action='store_true',
                        help='Allow weights requantization.')
+    group.add_argument('--qe-convert-pytorch', '--qept', action='store_true',
+                       help='Convert the model to PyTorch native post-train quantization modules')
+    group.add_argument('--qe-pytorch-backend', default='fbgemm', choices=['fbgemm', 'qnnpack'],
+                       help='When --qe-convert-pytorch is set, specifies the PyTorch quantization backend to use')
 
     stats_group = group.add_mutually_exclusive_group()
     stats_group.add_argument('--qe-stats-file', type=str, metavar='PATH',
@@ -657,6 +666,15 @@ class RangeLinearQuantWrapper(nn.Module):
         """
         raise NotImplementedError
 
+    def to_pytorch_quant(self, reduce_range):
+        assert self.output_quant_settings.num_bits == 8, \
+            'Conversion to PyTorch PTQ supported only for 8-bit quantization'
+        assert self.preset_act_stats, 'Conversion to PyTorch PTQ supported only for PTQ wrappers with activation stats'
+        return self._convert_to_pytorch_quant(reduce_range)
+
+    def _convert_to_pytorch_quant(self, reduce_range):
+        raise NotImplementedError
+
     def extra_repr(self):
         if self.output_quant_settings.num_bits is None:
             return 'output_quant_settings=Not_Quantized'
@@ -952,6 +970,63 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
             requant_scale = approx_scale_as_mult_and_shift(requant_scale, self.scale_approx_mult_bits)
         return requant_scale, output_zero_point
 
+    def _convert_to_pytorch_quant(self, reduce_range):
+        wrapped = self.wrapped_module
+        supported = (nn.Conv2d, nn.Linear)
+        # Tuple of module type and flag for relu fusing
+        mapping = {
+            (nn.Linear, False): nnq.Linear,
+            (nn.Linear, True): nniq.LinearReLU,
+            (nn.Conv2d, False): nnq.Conv2d,
+            (nn.Conv2d, True): nniq.ConvReLU2d
+        }
+        if nn.Conv3d in torch.quantization.DEFAULT_MODULE_MAPPING:
+            # Conv3D supported only from PyTorch 1.4
+            supported += nn.Conv3d,
+            mapping.update({
+                (nn.Conv3d, False): nnq.Conv3d,
+                (nn.Conv3d, True): nniq.ConvReLU3d,
+            })
+        assert isinstance(wrapped, supported), \
+            'Conversion to PyTorch PTQ supported only for {}'.format(', '.join(c.__name__ for c in supported))
+        assert self.wts_quant_settings.num_bits == 8, 'Conversion to PyTorch PTQ supported only for 8-bit quantization'
+
+        # Convert weights - required by PyTorch to be signed 8-bit (torch.qint8)
+        q_weight = pytqc.distiller_quantized_tensor_to_pytorch(wrapped.weight.clone().detach(),
+                                                               self.w_scale, self.w_zero_point,
+                                                               self.wts_quant_settings.num_bits,
+                                                               self.wts_quant_settings.quant_mode, torch.qint8,
+                                                               self.wts_quant_settings.per_channel, 0)
+
+        # PyTorch PTQ modules expect the bias in FP32, we need to dequantize if necessary
+        # With Distiller PTQ the bias is only quantized on the first forward - we do a crude check if it has
+        # been quantized or not
+        fp_bias = wrapped.bias.clone().detach() if self.has_bias else None
+        if self.has_bias:
+            bias_quantized = (fp_bias == fp_bias.int()).all()
+            if bias_quantized:
+                fp_bias = linear_dequantize(fp_bias, self.accum_scale.squeeze(), 0, True)
+
+        pytorch_cls = mapping[(type(wrapped), self.clip_half_range)]
+        if isinstance(wrapped, nn.Linear):
+            pytorch_module = pytorch_cls(wrapped.in_features, wrapped.out_features, wrapped.bias is not None)
+        else:
+            pytorch_module = pytorch_cls(wrapped.in_channels, wrapped.out_channels, wrapped.kernel_size,
+                                         wrapped.stride, wrapped.padding, wrapped.dilation, wrapped.groups,
+                                         wrapped.bias is not None, wrapped.padding_mode)
+
+        pytorch_module.set_weight_bias(q_weight, fp_bias)
+
+        # Convert activations qparams - required by PyTorch to be unsigned 8-bit (torch.quint8)
+        out_scale, out_zp = pytqc.distiller_qparams_to_pytorch(self.output_scale, self.output_zero_point,
+                                                               self.output_quant_settings.num_bits,
+                                                               self.output_quant_settings.quant_mode, torch.quint8,
+                                                               reduce_range)
+        pytorch_module.scale = float(out_scale)
+        pytorch_module.zero_point = int(out_zp)
+
+        return pytorch_module
+
     def extra_repr(self):
         tmpstr = 'weights_quant_settings={0}\n'.format(self.wts_quant_settings)
         tmpstr += super(RangeLinearQuantParamLayerWrapper, self).extra_repr()
@@ -1026,6 +1101,19 @@ class RangeLinearQuantMatmulWrapper(RangeLinearQuantWrapper):
             requant_scale = approx_scale_as_mult_and_shift(requant_scale, self.scale_approx_mult_bits)
         return requant_scale, output_zero_point
 
+    def _convert_to_pytorch_quant(self, reduce_range):
+        # Convert activations qparams - required by PyTorch to be unsigned 8-bit (torch.quint8)
+        scale, zp = pytqc.distiller_qparams_to_pytorch(self.output_scale, self.output_zero_point,
+                                                       self.output_quant_settings.num_bits,
+                                                       self.output_quant_settings.quant_mode, torch.quint8,
+                                                       reduce_range)
+        modules = [self.wrapped_module, nnq.Quantize(float(scale), int(zp), torch.quint8)]
+        if self.clip_half_range:
+            # The scale factor calculated in Distiller already considers the ReLU, so it's OK to apply the
+            # ReLU after quantization
+            modules.append(nnq.ReLU())
+        return modules
+
 
 class NoStatsError(NotImplementedError):
     pass
@@ -1065,6 +1153,21 @@ class RangeLinearQuantConcatWrapper(RangeLinearQuantWrapper):
         # Nothing to do here, since we already re-quantized in quantized_forward prior to the actual concatenation
         return 1., self.output_zero_point
 
+    def _convert_to_pytorch_quant(self, reduce_range):
+        # Convert activations qparams - required by PyTorch to be unsigned 8-bit (torch.quint8)
+        scale, zp = pytqc.distiller_qparams_to_pytorch(self.output_scale, self.output_zero_point,
+                                                       self.output_quant_settings.num_bits,
+                                                       self.output_quant_settings.quant_mode, torch.quint8,
+                                                       reduce_range)
+        m = pytqc.QFunctionalCat(self.wrapped_module.dim)
+        m.qfunc.scale = float(scale)
+        m.qfunc.zero_point = int(zp)
+        if self.clip_half_range:
+            # The scale factor calculated in Distiller already considers the ReLU, so it's OK to apply the
+            # ReLU after quantization
+            m = nn.Sequential(m, nnq.ReLU())
+        return m
+
 
 class RangeLinearQuantEltwiseAddWrapper(RangeLinearQuantWrapper):
     def __init__(self, wrapped_module, num_bits_acts, mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE,
@@ -1101,6 +1204,17 @@ class RangeLinearQuantEltwiseAddWrapper(RangeLinearQuantWrapper):
     def get_accum_to_output_re_quantization_params(self, output_scale, output_zero_point):
         return 1., self.output_zero_point
 
+    def _convert_to_pytorch_quant(self, reduce_range):
+        # Convert activations qparams - required by PyTorch to be unsigned 8-bit (torch.quint8)
+        scale, zp = pytqc.distiller_qparams_to_pytorch(self.output_scale, self.output_zero_point,
+                                                       self.output_quant_settings.num_bits,
+                                                       self.output_quant_settings.quant_mode, torch.quint8,
+                                                       reduce_range)
+        m = pytqc.QFunctionalAddRelu() if self.clip_half_range else pytqc.QFunctionalAdd()
+        m.qfunc.scale = float(scale)
+        m.qfunc.zero_point = int(zp)
+        return m
+
 
 class RangeLinearQuantEltwiseMultWrapper(RangeLinearQuantWrapper):
     def __init__(self, wrapped_module, num_bits_acts, mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE,
@@ -1142,6 +1256,21 @@ class RangeLinearQuantEltwiseMultWrapper(RangeLinearQuantWrapper):
             requant_scale = approx_scale_as_mult_and_shift(requant_scale, self.scale_approx_mult_bits)
         return requant_scale, output_zero_point
 
+    def _convert_to_pytorch_quant(self, reduce_range):
+        # Convert activations qparams - required by PyTorch to be unsigned 8-bit (torch.quint8)
+        scale, zp = pytqc.distiller_qparams_to_pytorch(self.output_scale, self.output_zero_point,
+                                                       self.output_quant_settings.num_bits,
+                                                       self.output_quant_settings.quant_mode, torch.quint8,
+                                                       reduce_range)
+        m = pytqc.QFunctionalMul()
+        m.qfunc.scale = float(scale)
+        m.qfunc.zero_point = int(zp)
+        if self.clip_half_range:
+            # The scale factor calculated in Distiller already considers the ReLU, so it's OK to apply the
+            # ReLU after quantization
+            m = nn.Sequential(m, nnq.ReLU())
+        return m
+
 
 class FPWrapper(nn.Module):
     """
@@ -1342,6 +1471,34 @@ class RangeLinearFakeQuantWrapper(RangeLinearQuantWrapper):
     def get_accum_to_output_re_quantization_params(self, output_scale, output_zero_point):
         return output_scale, output_zero_point
 
+    def _convert_to_pytorch_quant(self, reduce_range):
+        # A few PyTorch modules support quantized inputs/outputs
+        supported = {
+            nn.ReLU: nnq.ReLU(),
+            nn.ReLU6: nnq.ReLU6(),
+            nn.AvgPool2d: self.wrapped_module,
+            nn.AdaptiveAvgPool2d: self.wrapped_module,
+            nn.MaxPool2d: self.wrapped_module
+        }
+        q_module = supported.get(type(self.wrapped_module), None)
+        if q_module is None:
+            # No PyTorch quantized module - so fake it
+            # Convert activations qparams - required by PyTorch to be unsigned 8-bit (torch.quint8)
+            scale, zp = pytqc.distiller_qparams_to_pytorch(self.output_scale, self.output_zero_point,
+                                                           self.output_quant_settings.num_bits,
+                                                           self.output_quant_settings.quant_mode, torch.quint8,
+                                                           reduce_range)
+            modules = [pytqc.ConditionalDeQuantizeWrapper(self.wrapped_module),
+                       nnq.Quantize(float(scale), int(zp), torch.quint8)]
+        else:
+            modules = [self.wrapped_module]
+        if self.clip_half_range:
+            # The scale factor calculated in Distiller already considers the ReLU, so it's OK to apply the
+            # ReLU after quantization
+            modules.append(nnq.ReLU())
+
+        return modules[0] if len(modules) == 1 else nn.Sequential(*modules)
+
     def extra_repr(self):
         tmpstr = super(RangeLinearFakeQuantWrapper, self).extra_repr()
         if self.dtype:
@@ -1959,6 +2116,19 @@ class PostTrainLinearQuantizer(Quantizer):
             buffer.data = buffer.data.to(device)
         self.linear_quant_params = OrderedDict(self.named_linear_quant_params())
 
+    def convert_to_pytorch(self, dummy_input, backend='fbgemm'):
+        """
+        Convert a model quantized using distiller.quantization.PostTrainLinearQuantizer to a model comprised solely of
+        native PyTorch static post-training quantization modules and operators.
+
+        This is a convenience wrapper around distiller.quantization.convert_distiller_ptq_model_to_pytorch
+        See that function's documentation for more details.
+        """
+        if not self.prepared:
+            raise RuntimeError("Must call prepare_model before attempting to convert to PyTorch")
+
+        return pytqc.convert_distiller_ptq_model_to_pytorch(self.model, dummy_input, backend=backend)
+
 
 ###############################################################################
 # Quantization-aware training
diff --git a/examples/quantization/post_train_quant/command_line.md b/examples/quantization/post_train_quant/command_line.md
index b168b52..ce2c948 100644
--- a/examples/quantization/post_train_quant/command_line.md
+++ b/examples/quantization/post_train_quant/command_line.md
@@ -24,13 +24,25 @@ Post-training quantization can either be configured straight from the command-li
 | `--qe-clip-acts`         | `--qeca`  | Set activations clipping mode. Choices: "none", "avg", "n_std", "gauss", "laplace"    | "none"  |
 | `--qe-clip-n-stds`       | N/A       | When qe-clip-acts is set to 'n_std', this is the number of standard deviations to use | None    |
 | `--qe-no-clip-layers`    | `--qencl` | List of layer names (space-separated) for which not to clip activations               | ''      |
+| `--qe-no-quant-layers`   | `--qenql` | List of layer names (space-separated) for which to skip quantization                  | ''      |
 | `--qe-per-channel`       | `--qepc`  | Enable per-channel quantization of weights (per output channel)                       | Off     |
 | `--qe-scale-approx-bits` | `--qesab` | Enables scale factor approximation using integer multiply + bit shift, using this number of bits the integer multiplier | None |
 | `--qe-stats-file`        | N/A       | Use stats file for static quantization of activations. See details below              | None    |
 | `--qe-dynamic`           | N/A       | Perform dynamic quantization. See details below                                       | None    |
 | `--qe-config-file`       | N/A       | Path to YAML config file. See section above. (ignores all other --qe* arguments)      | None    |
+| `--qe-convert-pytorch`   | `--qept`  | Convert the model to PyTorch native post-train quantization modules                   | Off     |
+| `--qe-pytorch-backend`   | N/A       | When --qe-convert-pytorch is set, specifies the PyTorch quantization backend to use. Choices: "fbgemm", "qnnpack"   | "fbgemm" |
 
-(Note that these arguments can be added to any `argparse.ArgumentParser` by calling `distiller.quantization.add_post_train_quant_args()` and passing an existing parser)
+### Notes
+
+1. These arguments can be added to any `argparse.ArgumentParser` by calling `distiller.quantization.add_post_train_quant_args()` and passing an existing parser. This is provided as a convenience only. If you are writing a script and adding these arguments, it is up to you to implement the actual functionality implied by them.
+2. The `--qe-convert-pytorch` flag works in two settings:
+    * `--quantize-eval` is also set, in which case an FP32 model is first quantized using Distiller's post-training quantization flow, and then converted to a PyTorch native quantization model.
+    * `--quantize-eval` is not set, but a previously post-train quantized model is loaded via `--resume`. In this case, the loaded model is converted to PyTorch native quantization.
+
+### Conversion to PyTorch Built-in Quantization Model
+
+PyTorch released built-in support for quantization in version 1.3. Currently, Distiller's quantization functionality is completely separate from PyTorch's. We provide the ability to take a model which was post-train quantized with Distiller and is comprised of `RangeLinearQuantWrapper` modules, and convert it to a model comprised entirely of PyTorch's built-in quantized modules and operators. The conversion is triggered by the `--qe-convert-pytorch` flag, and the backend is selected with `--qe-pytorch-backend`.
 
 ## "Net-Aware" Quantization
 
diff --git a/jupyter/post_train_quant_convert_pytorch.ipynb b/jupyter/post_train_quant_convert_pytorch.ipynb
new file mode 100644
index 0000000..74e4d0a
--- /dev/null
+++ b/jupyter/post_train_quant_convert_pytorch.ipynb
@@ -0,0 +1,481 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Convert Distiller Post-Train Quantization Models to \"Native\" PyTorch\n",
+    "\n",
+    "## Background\n",
+    "\n",
+    "As of version 1.3 PyTorch comes with built-in quantization functionality. Details are available [here](https://pytorch.org/docs/stable/quantization.html). Distiller's and PyTorch's implementations are completely unrelated. An advantage of PyTorch built-in quantization is that it offers optimized 8-bit execution on CPU and export to GLOW. PyTorch doesn't offer optimized 8-bit execution on GPU (as of version 1.4).\n",
+    "\n",
+    "At the moment we are still keeping Distiller's separate API and implementation, but we've added the capability to convert a **post-training quantization** model created in Distiller to a \"Distiller-free\" model, comprised entirely of PyTorch built-in quantized modules.\n",
+    "\n",
+    "Distiller's quantized layers are actually simulated in FP32. Hence, comparing a Distiller model running on CPU to a PyTorch built-in model, the latter will be significantly faster on CPU. However, a Distiller model on a GPU is still likely to be faster compared to a PyTorch model on CPU. So experimenting with Distiller and converting to PyTorch in the end could be useful. Milage may vary of course, depending on the actual HW setup.\n",
+    "\n",
+    "Let's see how the conversion works."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "import matplotlib.pyplot as plt\n",
+    "import os\n",
+    "import math\n",
+    "import torchnet as tnt\n",
+    "from ipywidgets import widgets, interact\n",
+    "from copy import deepcopy\n",
+    "from collections import OrderedDict\n",
+    "\n",
+    "import distiller\n",
+    "from distiller.models import create_model\n",
+    "import distiller.quantization as quant\n",
+    "\n",
+    "# Load some common code and configure logging\n",
+    "# We do this so we can see the logging output coming from\n",
+    "# Distiller function calls\n",
+    "%run './distiller_jupyter_helpers.ipynb'\n",
+    "msglogger = config_notebooks_logger()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Create Model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# By default, the model is moved to the GPU and parallelized (wrapped with torch.nn.DataParallel)\n",
+    "# If no GPU is available, a non-parallel model is created on the CPU\n",
+    "model = create_model(pretrained=True, dataset='imagenet', arch='resnet18', parallel=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Create Data Loaders\n",
+    "\n",
+    "We create separate data loaders for GPU and CPU. Set `batch_size` and `num_workers` to optimal values that match your HW setup.\n",
+    "\n",
+    "(Note we reset the seed before creating each data loader, to make sure both loaders consist of the same subset of the test set)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# We use Distiller's built-in data loading functionality for ImageNet\n",
+    "\n",
+    "distiller.set_seed(0)\n",
+    "\n",
+    "subset_size = 1.0 # To save time, can set to value < 1.0\n",
+    "dataset = 'imagenet'\n",
+    "dataset_path = os.path.expanduser('/data2/datasets/imagenet')\n",
+    "\n",
+    "batch_size_gpu = 256\n",
+    "num_workers_gpu = 10\n",
+    "_, _, test_loader_gpu, _ = distiller.apputils.load_data(\n",
+    "    dataset, dataset_path, batch_size_gpu, num_workers_gpu,\n",
+    "    effective_test_size=subset_size, fixed_subset=True, test_only=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "distiller.set_seed(0)\n",
+    "batch_size_cpu = 44\n",
+    "num_workers_cpu = 10\n",
+    "_, _, test_loader_cpu, _ = distiller.apputils.load_data(\n",
+    "    dataset, dataset_path, batch_size_cpu, num_workers_cpu,\n",
+    "    effective_test_size=subset_size, fixed_subset=True, test_only=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Define Evaluation Function"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def eval_model(data_loader, model, device, print_freq=10):\n",
+    "    print('Evaluating model')\n",
+    "    criterion = torch.nn.CrossEntropyLoss().to(device)\n",
+    "    \n",
+    "    loss = tnt.meter.AverageValueMeter()\n",
+    "    classerr = tnt.meter.ClassErrorMeter(accuracy=True, topk=(1, 5))\n",
+    "\n",
+    "    total_samples = len(data_loader.sampler)\n",
+    "    batch_size = data_loader.batch_size\n",
+    "    total_steps = math.ceil(total_samples / batch_size)\n",
+    "    print('{0} samples ({1} per mini-batch)'.format(total_samples, batch_size))\n",
+    "\n",
+    "    # Switch to evaluation mode\n",
+    "    model.eval()\n",
+    "\n",
+    "    for step, (inputs, target) in enumerate(data_loader):\n",
+    "        with torch.no_grad():\n",
+    "            inputs, target = inputs.to(device), target.to(device)\n",
+    "            # compute output from model\n",
+    "            output = model(inputs)\n",
+    "\n",
+    "            # compute loss and measure accuracy\n",
+    "            loss.add(criterion(output, target).item())\n",
+    "            classerr.add(output.data, target)\n",
+    "            \n",
+    "            if (step + 1) % print_freq == 0:\n",
+    "                print('[{:3d}/{:3d}] Top1: {:.3f}  Top5: {:.3f}  Loss: {:.3f}'.format(\n",
+    "                      step + 1, total_steps, classerr.value(1), classerr.value(5), loss.mean), flush=True)\n",
+    "    print('----------')\n",
+    "    print('Overall ==> Top1: {:.3f}  Top5: {:.3f}  Loss: {:.3f}'.format(\n",
+    "        classerr.value(1), classerr.value(5), loss.mean), flush=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Post-Train Quantize with Distiller"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "quant_mode = {'activations': 'ASYMMETRIC_UNSIGNED', 'weights': 'SYMMETRIC'}\n",
+    "stats_file = \"../examples/quantization/post_train_quant/stats/resnet18_quant_stats.yaml\"\n",
+    "dummy_input = distiller.get_dummy_input(input_shape=model.input_shape)\n",
+    "\n",
+    "quantizer = quant.PostTrainLinearQuantizer(\n",
+    "    deepcopy(model), bits_activations=8, bits_parameters=8, mode=quant_mode,\n",
+    "    model_activation_stats=stats_file, overrides=None\n",
+    ")\n",
+    "quantizer.prepare_model(dummy_input)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Convert to PyTorch Built-In"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Here we trigger the conversion via the Quantizer instance. Later on we show another way which does not\n",
+    "# require the quantizer\n",
+    "pyt_model = quantizer.convert_to_pytorch(dummy_input)\n",
+    "\n",
+    "# Note that the converted model is automatically moved to the CPU, regardless\n",
+    "# of the device of the Distiller model\n",
+    "print('Distiller model device:', distiller.model_device(quantizer.model))\n",
+    "print('PyTorch model device:', distiller.model_device(pyt_model))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Run Evaluation\n",
+    "### Distiller Model on GPU (if available)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "if torch.cuda.is_available():\n",
+    "    %time eval_model(test_loader_gpu, quantizer.model, 'cuda')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Distiller Model on CPU"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "if torch.cuda.is_available():\n",
+    "    print('Creating CPU copy of Distiller model')\n",
+    "    cpu_model = distiller.make_non_parallel_copy(quantizer.model).cpu()\n",
+    "else:\n",
+    "    cpu_model = quantizer.model\n",
+    "%time eval_model(test_loader_cpu, cpu_model, 'cpu', print_freq=60)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### PyTorch model in CPU\n",
+    "\n",
+    "We expect the PyTorch model on CPU to be much faster than the Distiller model on CPU"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%time eval_model(test_loader_cpu, pyt_model, 'cpu', print_freq=60)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## For the Extra-Curious: Comparing the Models"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "1. Distiller takes care of quantizing the inputs within the quantized modules PyTorch quantized modules assume the input is already quantized. Hence, for cases where a module's input is not quantized, we explicitly add a quantization operation for the input. The first layer in the model, `conv1` in ResNet18, is such a case\n",
+    "2. Both Distiller and native PyTorch support fused ReLU. In Distiller, this is somewhat obscurely indicated by the `clip_half_range` attribute inside `output_quant_settings`. In PyTorch, the module type is explicitly `QuantizedConvReLU2d`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print('conv1\\n')\n",
+    "print('DISTILLER:\\n{}\\n'.format(quantizer.model.module.conv1))\n",
+    "print('PyTorch:\\n{}\\n'.format(pyt_model.conv1))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Example of internal layers which don't require explicit input quantization:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print('layer1.0.conv1')\n",
+    "print(pyt_model.layer1[0].conv1)\n",
+    "print('\\nlayer1.0.add')\n",
+    "print(pyt_model.layer1[0].add)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Automatic de-quantization <--> quantization in the model\n",
+    "\n",
+    "For each quantized module in the Distiller implementation, we quantize the input and de-quantize the output.\n",
+    "So, if the user explicitly sets \"internal\" modules to run in FP32, this is transparent to the other quantized modules (at the cost of redundant quant-dequant operations).\n",
+    "\n",
+    "When converting to PyTorch we remove these redundant operations, and keep just the required ones in case the user explicitly decided to run some modules in FP32.\n",
+    "\n",
+    "For an example, consider a ResNet \"basic block\" with a residual connection that contains a downsampling convolution. Let's see how such a block looks in our fully-quantized, converted model:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print(pyt_model.layer2[0])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We can see all layers are either built-in quantized PyTorch modules, or identity operations representing fused operations. The entire block is quantized, so we don't see any quant-dequnt operations in the middle.\n",
+    "\n",
+    "Now let's create a new quantized model, and this time leave the 'downsample' module in FP32:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "overrides = OrderedDict(\n",
+    "    [('layer2.0.downsample.0', OrderedDict([('bits_activations', None), ('bits_weights', None)]))]\n",
+    ")\n",
+    "new_quantizer = quant.PostTrainLinearQuantizer(\n",
+    "    deepcopy(model), bits_activations=8, bits_parameters=8, mode=quant_mode,\n",
+    "    model_activation_stats=stats_file, overrides=overrides\n",
+    ")\n",
+    "new_quantizer.prepare_model(dummy_input)\n",
+    "\n",
+    "new_pyt_model = new_quantizer.convert_to_pytorch(dummy_input)\n",
+    "\n",
+    "print(new_pyt_model.layer2[0])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We can see a few differences:\n",
+    "1. The `downsample` module now contains a de-quantize op before the actual convolution\n",
+    "2. The `add` module now contains a quantize op before the actual add. Note that the add operation accepts 2 inputs. In this case the first input (index 0) comes from the `conv2` module, which is quantized. The second input (index 1) comes from the `downsample` module, which we kept in FP32. So, we only need to quantized the input at index 1. We can see this is indeed what is happening, by looking at the `ModuleDict` inside the `quant` module, and noticing it has only a single key for index \"1\".\n",
+    "\n",
+    "Let's see how the `add` module would look if we also kept the `conv2` module in FP32:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "overrides = OrderedDict(\n",
+    "    [('layer2.0.downsample.0', OrderedDict([('bits_activations', None), ('bits_weights', None)])),\n",
+    "     ('layer2.0.conv2', OrderedDict([('bits_activations', None), ('bits_weights', None)]))]\n",
+    ")\n",
+    "new_quantizer = quant.PostTrainLinearQuantizer(\n",
+    "    deepcopy(model), bits_activations=8, bits_parameters=8, mode=quant_mode,\n",
+    "    model_activation_stats=stats_file, overrides=overrides\n",
+    ")\n",
+    "new_quantizer.prepare_model(dummy_input)\n",
+    "\n",
+    "new_pyt_model = new_quantizer.convert_to_pytorch(dummy_input)\n",
+    "\n",
+    "print(new_pyt_model.layer2[0].add)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We can see that now both inputs to the add module are being quantized."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Another API for Conversion\n",
+    "\n",
+    "In some cases we don't have the actual quantizer. For example - if the Distiller quantized module was loaded from a checkpoint. In those cases we can call a `distiller.quantization` module-level function (In fact, the Quantizer method we used earlier is a wrapper around this function).\n",
+    "\n",
+    "### Save Distiller model to checkpoint"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Save Distiller model to checkpoint and load it\n",
+    "distiller.apputils.save_checkpoint(0, 'resnet18', quantizer.model)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Load Checkpoint\n",
+    "\n",
+    "The model is quantized when the checkpoint is loaded"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "loaded_model = create_model(False, dataset='imagenet', arch='resnet18', parallel=True)\n",
+    "loaded_model = distiller.apputils.load_lean_checkpoint(loaded_model, 'checkpoint.pth.tar')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Convert and Evaluate"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Convert\n",
+    "loaded_pyt_model = distiller.quantization.convert_distiller_ptq_model_to_pytorch(loaded_model, dummy_input)\n",
+    "\n",
+    "# Run evaluation\n",
+    "%time eval_model(test_loader_cpu, loaded_pyt_model, 'cpu', print_freq=60)\n",
+    "\n",
+    "# Cleanup\n",
+    "os.remove('checkpoint.pth.tar')"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.5.2"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/tests/test_ptq_pytorch_convert.py b/tests/test_ptq_pytorch_convert.py
new file mode 100644
index 0000000..81919cd
--- /dev/null
+++ b/tests/test_ptq_pytorch_convert.py
@@ -0,0 +1,126 @@
+#
+# Copyright (c) 2020 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import pytest
+
+import torch
+import torch.testing
+
+import distiller.quantization as quantization
+from distiller.quantization.range_linear import _get_quant_params_from_tensor
+
+
+@pytest.fixture(params=[quantization.LinearQuantMode.SYMMETRIC,
+                        quantization.LinearQuantMode.SYMMETRIC_RESTRICTED,
+                        quantization.LinearQuantMode.ASYMMETRIC_UNSIGNED,
+                        quantization.LinearQuantMode.ASYMMETRIC_SIGNED],
+                ids=['symmetric', 'symmetric_restricted', 'asym_unsigned', 'asym_signed'])
+def distiller_mode(request):
+    return request.param
+
+
+@pytest.fixture(params=[torch.qint8, torch.quint8], ids=['torch.qint8', 'torch.quint8'])
+def torch_dtype(request):
+    return request.param
+
+
+@pytest.fixture(params=[False, True], ids=['per_tensor', 'per_channel'])
+def per_channel(request):
+    return request.param
+
+
+@pytest.fixture(params=[False, True], ids=['reduce_off', 'reduce_on'])
+def reduce_range(request):
+    return request.param
+
+
+@pytest.fixture()
+def num_bits():
+    return 8
+
+
+@pytest.fixture()
+def tensor():
+    return torch.randn(64, 256, 7, 7)
+
+
+def test_qparams_conversion(tensor, num_bits, distiller_mode, torch_dtype, per_channel, reduce_range):
+    if reduce_range:
+        if num_bits != 8:
+            pytest.skip('reduce_range is covered only for 8-bit quantization')
+        if quantization.is_linear_quant_mode_symmetric(distiller_mode) and torch_dtype == torch.quint8:
+            pytest.skip('reduce_range with a symmetric mode and torch.quint8 is not covered')
+
+    # Calculate quantization parameters with Distiller for the number of bits BEFORE reduce_range
+    signed = distiller_mode != quantization.LinearQuantMode.ASYMMETRIC_UNSIGNED
+    distiller_scale, distiller_zp = _get_quant_params_from_tensor(tensor, num_bits, distiller_mode,
+                                                                  per_channel=per_channel)
+
+    # Convert parameters to PyTorch
+    converted_scale, converted_zp = quantization.distiller_qparams_to_pytorch(
+        distiller_scale, distiller_zp, num_bits, distiller_mode, torch_dtype, reduce_range
+    )
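+    # For reference: Distiller quantizes as q = round(x * scale) - zero_point, whereas
+    # PyTorch quantizes as q = round(x / scale + zero_point), so the converted scale and
+    # zero-point are expressed in PyTorch's convention (accounting for dtype and reduce_range).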
+
+    # Quantize tensor with Distiller
+    # If reduce_range is set, then we actually quantize with num_bits-1
+    if reduce_range:
+        num_bits -= 1
+        distiller_scale, distiller_zp = _get_quant_params_from_tensor(tensor, num_bits, distiller_mode,
+                                                                      per_channel=per_channel)
+    restrict = distiller_mode == quantization.LinearQuantMode.SYMMETRIC_RESTRICTED
+    clamp_min, clamp_max = quantization.get_quantized_range(num_bits, signed=signed, signed_restrict_qrange=restrict)
+    distiller_q_t = quantization.linear_quantize_clamp(tensor, distiller_scale, distiller_zp, clamp_min, clamp_max)
+
+    # Quantize with PyTorch
+    if per_channel:
+        pytorch_q_t = torch.quantize_per_channel(tensor, converted_scale, converted_zp, 0, torch_dtype)
+    else:
+        pytorch_q_t = torch.quantize_per_tensor(tensor, converted_scale, converted_zp, torch_dtype)
+
+    # Dequantize
+    distiller_q_dq_t = quantization.linear_dequantize(distiller_q_t, distiller_scale, distiller_zp)
+    pytorch_q_dq_t = pytorch_q_t.dequantize()
+
+    # Compare - allow a difference of up to one quantized "bin" between the tensors
+    if per_channel:
+        for idx, scale in enumerate(converted_scale):
+            torch.testing.assert_allclose(distiller_q_dq_t[idx], pytorch_q_dq_t[idx], atol=scale, rtol=1e-05)
+    else:
+        torch.testing.assert_allclose(pytorch_q_dq_t, distiller_q_dq_t, atol=converted_scale, rtol=1e-05)
+
+
+def test_quantized_tensor_conversion(tensor, num_bits, distiller_mode, torch_dtype, per_channel):
+    # Quantize tensor with Distiller
+    signed = distiller_mode != quantization.LinearQuantMode.ASYMMETRIC_UNSIGNED
+    distiller_scale, distiller_zp = _get_quant_params_from_tensor(tensor, num_bits, distiller_mode,
+                                                                  per_channel=per_channel)
+    restrict = distiller_mode == quantization.LinearQuantMode.SYMMETRIC_RESTRICTED
+    clamp_min, clamp_max = quantization.get_quantized_range(num_bits, signed=signed, signed_restrict_qrange=restrict)
+    distiller_q_t = quantization.linear_quantize_clamp(tensor, distiller_scale, distiller_zp, clamp_min, clamp_max)
+
+    # Convert tensor to PyTorch
+    pytorch_q_t = quantization.distiller_quantized_tensor_to_pytorch(
+        distiller_q_t, distiller_scale, distiller_zp, num_bits, distiller_mode, torch_dtype, per_channel, 0
+    )
+
+    # Dequantize both
+    distiller_q_dq_t = quantization.linear_dequantize(distiller_q_t, distiller_scale, distiller_zp)
+    pytorch_q_dq_t = pytorch_q_t.dequantize()
+
+    # Compare
+    torch.testing.assert_allclose(pytorch_q_dq_t, distiller_q_dq_t)
+
+
+# TODO: Add tests of full model conversion
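+
+
+# A possible starting point for the TODO above - a sketch of a full-model conversion test,
+# not a verified part of this suite. It assumes a pre-collected activation stats YAML file
+# is available (the path below is a hypothetical placeholder, so the test skips itself when
+# the file is missing), and the output tolerance at the end is a rough guess.
+def test_full_model_conversion_sketch():
+    import os
+    from distiller.models import create_model
+
+    stats_file = 'acts_quantization_stats.yaml'  # hypothetical placeholder path
+    if not os.path.isfile(stats_file):
+        pytest.skip('Activation stats file not available')
+
+    # FP32 ResNet-18 on the CPU (PyTorch quantized ops run on the CPU)
+    model = create_model(False, dataset='imagenet', arch='resnet18', parallel=False).cpu()
+    model.eval()
+    dummy_input = torch.randn(1, 3, 224, 224)
+
+    # Post-train quantize with Distiller, then convert to native PyTorch quantization
+    quantizer = quantization.PostTrainLinearQuantizer(
+        model, bits_activations=8, bits_parameters=8,
+        mode=quantization.LinearQuantMode.ASYMMETRIC_UNSIGNED,
+        model_activation_stats=stats_file)
+    quantizer.prepare_model(dummy_input)
+    pyt_model = quantizer.convert_to_pytorch(dummy_input)
+
+    # The converted model quantizes its input and de-quantizes its output, so both models
+    # accept and return FP32 tensors. Outputs are not expected to be bit-exact.
+    with torch.no_grad():
+        distiller_out = quantizer.model(dummy_input)
+        pyt_out = pyt_model(dummy_input)
+    torch.testing.assert_allclose(pyt_out, distiller_out, atol=0.1, rtol=0)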
-- 
GitLab