From e82d938077f9c5abc81b79cfadf245f44fabe0ba Mon Sep 17 00:00:00 2001
From: Guy Jacob <guy.jacob@intel.com>
Date: Mon, 6 Jan 2020 13:38:27 +0200
Subject: [PATCH] Post-train quant: Refactor inputs quantization (#454)

* Fake quant wrapper now also works on (fake) quantized inputs (see the sketch below)
* Remove 'requires_quantized_inputs' flag
* Unrelated: Moved LinearQuantMode enum to q_utils
---
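Reviewer note: the gist of the change is that every RangeLinearQuantWrapper subclass now
prepares each of its inputs through a per-input hook (_prepare_input), and the fake-quant
wrapper overrides that hook instead of being special-cased via the removed
'requires_quantized_inputs' flag. Below is a minimal, self-contained sketch of that
behavior. It is NOT Distiller's actual API: ToyQuantWrapper, ToyFakeQuantWrapper and the
helper functions are illustrative stand-ins with deliberately simplified math.

    import torch
    from collections import namedtuple

    # Illustrative stand-ins for distiller's TensorQuantMetadata and linear quant
    # helpers; only the names mirror the patch, the math is intentionally minimal.
    TensorQuantMetadata = namedtuple('TensorQuantMetadata',
                                     ['scale', 'zero_point', 'min_q_val', 'max_q_val'])

    def linear_quantize_clamp(t, scale, zero_point, min_q_val, max_q_val):
        return torch.clamp(torch.round(scale * t - zero_point), min_q_val, max_q_val)

    def linear_dequantize(t_q, scale, zero_point):
        return (t_q + zero_point) / scale

    class ToyQuantWrapper(torch.nn.Module):
        """Post-refactor base behavior: every input goes through _prepare_input()."""
        def __init__(self, wrapped_module, num_bits=8):
            super().__init__()
            self.wrapped_module = wrapped_module
            self.num_bits = num_bits

        def _fallback_metadata(self, t):
            # Symmetric per-tensor params derived from the tensor itself (stand-in
            # for the stats-based / auto-fallback path in the real wrapper)
            sat = t.abs().max().clamp(min=1e-8)
            n = 2 ** (self.num_bits - 1) - 1
            return TensorQuantMetadata(n / sat, torch.tensor(0.), -n - 1, n)

        def _prepare_input(self, idx, t):
            # Default: quantize the input; real layers then compute in the integer
            # domain (output re-quantization / de-quantization omitted here)
            qmd = getattr(t, 'quant_metadata', None) or self._fallback_metadata(t)
            t_q = linear_quantize_clamp(t, qmd.scale, qmd.zero_point,
                                        qmd.min_q_val, qmd.max_q_val)
            t_q.quant_metadata = qmd
            return t_q

        def forward(self, *inputs):
            inputs_q = [self._prepare_input(idx, t) for idx, t in enumerate(inputs)]
            return self.wrapped_module(*inputs_q)

    class ToyFakeQuantWrapper(ToyQuantWrapper):
        def _prepare_input(self, idx, t):
            if hasattr(t, 'quant_metadata'):
                # Input was already (fake-)quantized upstream - pass it through
                return t
            # "Fresh" tensor: quantize, then de-quantize, so the wrapped module
            # sees a float tensor that carries the quantization error
            qmd = self._fallback_metadata(t)
            t_q = linear_quantize_clamp(t, qmd.scale, qmd.zero_point,
                                        qmd.min_q_val, qmd.max_q_val)
            t_fq = linear_dequantize(t_q, qmd.scale, qmd.zero_point)
            t_fq.quant_metadata = qmd
            return t_fq

    if __name__ == '__main__':
        fq = ToyFakeQuantWrapper(torch.nn.ReLU())
        x = torch.randn(4)
        y_fresh = fq(x)                                # no metadata -> fake-quantized inside

        x2 = torch.randn(4)
        x2.quant_metadata = fq._fallback_metadata(x2)  # pretend an upstream layer quantized it
        y_pre = fq(x2)                                 # metadata present -> used as-is

In the real wrappers, the metadata fallback additionally honors per-input overrides,
collected activation stats and the inputs_quant_auto_fallback setting, and the output is
re-quantized / de-quantized as before; the sketch only shows the input-preparation split.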
 distiller/quantization/__init__.py          |   6 +-
 distiller/quantization/ptq_greedy_search.py |   3 +-
 distiller/quantization/q_utils.py           |   6 +
 distiller/quantization/range_linear.py      | 238 ++++++++++----------
 4 files changed, 132 insertions(+), 121 deletions(-)

diff --git a/distiller/quantization/__init__.py b/distiller/quantization/__init__.py
index e24c976..553877f 100644
--- a/distiller/quantization/__init__.py
+++ b/distiller/quantization/__init__.py
@@ -16,9 +16,11 @@
 
 from .quantizer import Quantizer
 from .range_linear import RangeLinearQuantWrapper, RangeLinearQuantParamLayerWrapper, PostTrainLinearQuantizer, \
-    LinearQuantMode, QuantAwareTrainRangeLinearQuantizer, add_post_train_quant_args, NCFQuantAwareTrainQuantizer, \
-    RangeLinearQuantConcatWrapper, RangeLinearQuantEltwiseAddWrapper, RangeLinearQuantEltwiseMultWrapper, ClipMode
+    QuantAwareTrainRangeLinearQuantizer, add_post_train_quant_args, NCFQuantAwareTrainQuantizer, \
+    RangeLinearQuantConcatWrapper, RangeLinearQuantEltwiseAddWrapper, RangeLinearQuantEltwiseMultWrapper, ClipMode, \
+    RangeLinearEmbeddingWrapper, RangeLinearFakeQuantWrapper, RangeLinearQuantMatmulWrapper
 from .clipped_linear import LinearQuantizeSTE, ClippedLinearQuantization, WRPNQuantizer, DorefaQuantizer, PACTQuantizer
+from .q_utils import *
 
 del quantizer
 del range_linear
diff --git a/distiller/quantization/ptq_greedy_search.py b/distiller/quantization/ptq_greedy_search.py
index 5333c4e..46d3b45 100644
--- a/distiller/quantization/ptq_greedy_search.py
+++ b/distiller/quantization/ptq_greedy_search.py
@@ -18,7 +18,8 @@ Here we implement the greedy search algorithm for automatic quantization.
 """
 import torch
 import torch.nn as nn
-from distiller.quantization.range_linear import PostTrainLinearQuantizer, ClipMode, LinearQuantMode
+from distiller.quantization import LinearQuantMode
+from distiller.quantization.range_linear import PostTrainLinearQuantizer, ClipMode
 from distiller.summary_graph import SummaryGraph
 from distiller.model_transforms import fold_batch_norms
 import distiller.modules
diff --git a/distiller/quantization/q_utils.py b/distiller/quantization/q_utils.py
index e1bad3e..8145dc7 100644
--- a/distiller/quantization/q_utils.py
+++ b/distiller/quantization/q_utils.py
@@ -18,6 +18,12 @@ from enum import Enum
 import torch
 
 
+class LinearQuantMode(Enum):
+    SYMMETRIC = 1
+    ASYMMETRIC_UNSIGNED = 2
+    ASYMMETRIC_SIGNED = 3
+
+
 def _prep_saturation_val_tensor(sat_val):
     is_scalar = not isinstance(sat_val, torch.Tensor)
     out = torch.tensor(sat_val) if is_scalar else sat_val.clone().detach()
diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index 0765e33..e567700 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -52,12 +52,6 @@ def _enum_to_str(enum_val):
     return str(enum_val).split('.')[1]
 
 
-class LinearQuantMode(Enum):
-    SYMMETRIC = 1
-    ASYMMETRIC_UNSIGNED = 2
-    ASYMMETRIC_SIGNED = 3
-
-
 class ModuleQuantMode(namedtuple('ModuleQuantMode', ['activations', 'weights'])):
     """
     Named tuple for configuring the LinearQuantMode of both weights and activations of a module
@@ -227,12 +221,21 @@ class QuantSettings(object):
 
 
 def linear_quantize_clamp_with_metadata(t, inplace=False):
-    return linear_quantize_clamp(t, *t.quant_metadata, inplace)
+    assert hasattr(t, 'quant_metadata')
+    qmd = t.quant_metadata
+    t = linear_quantize_clamp(t, *qmd, inplace)
+    if not inplace:
+        t.quant_metadata = qmd
+    return t
 
 
 def linear_dequantize_with_metadata(t, inplace=False):
+    assert hasattr(t, 'quant_metadata')
     qmd = t.quant_metadata
-    return linear_dequantize(t, qmd.scale, qmd.zero_point, inplace)
+    t = linear_dequantize(t, qmd.scale, qmd.zero_point, inplace)
+    if not inplace:
+        t.quant_metadata = qmd
+    return t
 
 
 def add_post_train_quant_args(argparser):
@@ -331,8 +334,7 @@ class RangeLinearQuantWrapper(nn.Module):
 
     def __init__(self, wrapped_module, num_bits_acts, num_bits_accum=32, mode=LinearQuantMode.SYMMETRIC,
                  clip_acts=ClipMode.NONE, activation_stats=None, clip_n_stds=None, clip_half_range=False,
-                 scale_approx_mult_bits=None,
-                 input_overrides=None, requires_quantized_inputs=True, inputs_quant_auto_fallback=False):
+                 scale_approx_mult_bits=None, input_overrides=None, inputs_quant_auto_fallback=False):
         super(RangeLinearQuantWrapper, self).__init__()
 
         input_overrides = input_overrides or OrderedDict()
@@ -342,7 +344,6 @@ class RangeLinearQuantWrapper(nn.Module):
         self.wrapped_module = wrapped_module
         self.clip_half_range = clip_half_range
         self.scale_approx_mult_bits = scale_approx_mult_bits
-        self.requires_quantized_inputs = requires_quantized_inputs
         self.inputs_quant_auto_fallback = inputs_quant_auto_fallback
 
         self.output_quant_settings = QuantSettings(num_bits_acts, mode.activations, clip_acts, clip_n_stds,
@@ -358,26 +359,25 @@ class RangeLinearQuantWrapper(nn.Module):
             return
 
         # Activations are quantized - setup quantization parameters
-        if self.requires_quantized_inputs:
-            self.inputs_quant_settings_overrides = OrderedDict()
-            for k, v in input_overrides.items():
-                idx = int(k)
-                if v.pop('from_output', None):
-                    quant_settings = deepcopy(self.output_quant_settings)
-                    quant_settings.clip_half_range = False
-                else:
-                    quant_settings = QuantSettings(
-                        v.pop('bits_activations', self.output_quant_settings.num_bits),
-                        verify_quant_mode(v.pop('mode', self.output_quant_settings.quant_mode)),
-                        verify_clip_mode(v.pop('clip_acts', self.output_quant_settings.clip_mode)),
-                        v.pop('clip_n_stds', self.output_quant_settings.clip_n_stds),
-                        False, False)
-                    if v:
-                        # Poor man's input checking on input overrides dict
-                        raise ValueError('Input overrides dict contains unsupported keys:', list(v.keys()))
-                self.inputs_quant_settings_overrides[idx] = quant_settings
-        else:
-            self.inputs_quant_settings_overrides = None
+
+        # Set up inputs quantization settings
+        self.inputs_quant_settings_overrides = OrderedDict()
+        for k, v in input_overrides.items():
+            idx = int(k)
+            if v.pop('from_output', None):
+                quant_settings = deepcopy(self.output_quant_settings)
+                quant_settings.clip_half_range = False
+            else:
+                quant_settings = QuantSettings(
+                    v.pop('bits_activations', self.output_quant_settings.num_bits),
+                    verify_quant_mode(v.pop('mode', self.output_quant_settings.quant_mode)),
+                    verify_clip_mode(v.pop('clip_acts', self.output_quant_settings.clip_mode)),
+                    v.pop('clip_n_stds', self.output_quant_settings.clip_n_stds),
+                    False, False)
+                if v:
+                    # Poor man's input checking on input overrides dict
+                    raise ValueError('Input overrides dict contains unsupported keys:', list(v.keys()))
+            self.inputs_quant_settings_overrides[idx] = quant_settings
 
         # Controls whether output is de-quantized at end of forward op. Meant as a debug / test flag only
         # (note that if False, the quantized output will be returned, but without any quantization parameters,
@@ -392,21 +392,20 @@ class RangeLinearQuantWrapper(nn.Module):
         if activation_stats:
             self.preset_act_stats = True
 
-            if self.requires_quantized_inputs:
-                self.inputs_quant_metadata_fallback = OrderedDict()
-                for idx, stats in activation_stats['inputs'].items():
-                    settings = self.inputs_quant_settings_overrides.get(idx, self.output_quant_settings)
-                    scale, zp = _get_quant_params_from_stats_dict(
-                        stats, settings.num_bits, settings.quant_mode, settings.clip_mode,
-                        settings.clip_n_stds, settings.clip_half_range, self.scale_approx_mult_bits
-                    )
-                    min_q_val, max_q_val = get_quantized_range(
-                        settings.num_bits, settings.quant_mode != LinearQuantMode.ASYMMETRIC_UNSIGNED)
-                    qmd = TensorQuantMetadata(scale, zp, min_q_val, max_q_val)
-                    self.inputs_quant_metadata_fallback[idx] = qmd
-            else:
-                self.inputs_quant_metadata_fallback = None
-
+            # Calculate inputs quantization parameters
+            self.inputs_quant_metadata_fallback = OrderedDict()
+            for idx, stats in activation_stats['inputs'].items():
+                settings = self.inputs_quant_settings_overrides.get(idx, self.output_quant_settings)
+                scale, zp = _get_quant_params_from_stats_dict(
+                    stats, settings.num_bits, settings.quant_mode, settings.clip_mode,
+                    settings.clip_n_stds, settings.clip_half_range, self.scale_approx_mult_bits
+                )
+                min_q_val, max_q_val = get_quantized_range(
+                    settings.num_bits, settings.quant_mode != LinearQuantMode.ASYMMETRIC_UNSIGNED)
+                qmd = TensorQuantMetadata(scale, zp, min_q_val, max_q_val)
+                self.inputs_quant_metadata_fallback[idx] = qmd
+
+            # Calculate output quantization parameters
             scale, zp = _get_quant_params_from_stats_dict(activation_stats['output'], num_bits_acts, mode.activations,
                                                           clip_acts, clip_n_stds, clip_half_range,
                                                           scale_approx_mult_bits)
@@ -436,19 +435,7 @@ class RangeLinearQuantWrapper(nn.Module):
         for buffer_name, buffer in self._buffers.items():
             setattr(self, buffer_name, buffer.to(device))
 
-        if self.requires_quantized_inputs:
-            self._prepare_inputs_for_quantization(inputs)
-
-            inputs_q = []
-            for input in inputs:
-                qmd = input.quant_metadata
-                input.quant_metadata = TensorQuantMetadata(qmd.scale.to(device), qmd.zero_point.to(device),
-                                                                        qmd.min_q_val, qmd.max_q_val)
-                input_q = linear_quantize_clamp_with_metadata(input)
-                input_q.quant_metadata = input.quant_metadata
-                inputs_q.append(input_q)
-        else:
-            inputs_q = inputs
+        inputs_q = [self._prepare_input(idx, input) for idx, input in enumerate(inputs)]
 
         # Forward through wrapped module
         accum = self.quantized_forward(*inputs_q)
@@ -474,42 +461,50 @@ class RangeLinearQuantWrapper(nn.Module):
 
         return out_f
 
-    def _prepare_inputs_for_quantization(self, inputs):
-        for idx, input in enumerate(inputs):
-            if hasattr(input, 'quant_metadata'):
-                if idx in self.inputs_quant_settings_overrides:
-                    raise RuntimeError('<{}> Input {}: CONFLICT - Tensor has embedded quantization metadata AND user '
-                                       'defined input quantization settings'.format(self.distiller_name, idx))
+    def _prepare_input(self, idx, input):
+        # Default implementation - quantize the input tensor
+        # This works for all but RangeLinearFakeQuantWrapper
+        input.quant_metadata = self._get_input_quant_metadata(idx, input)
+        return linear_quantize_clamp_with_metadata(input, inplace=False)
+
+    def _get_input_quant_metadata(self, idx, input):
+        if hasattr(input, 'quant_metadata'):
+            if idx in self.inputs_quant_settings_overrides:
+                raise RuntimeError('<{}> Input {}: CONFLICT - Tensor has embedded quantization metadata AND user '
+                                   'defined input quantization settings'.format(self.distiller_name, idx))
+            qmd = input.quant_metadata
+        else:
+            # Input doesn't have embedded quantization data propagated from a previous layer
+            # Our options are:
+            #  If user set explicit settings for this input, use those
+            #  OR
+            #  If auto fallback is set, use the output quantization settings
+            if idx not in self.inputs_quant_settings_overrides and not self.inputs_quant_auto_fallback:
+                raise RuntimeError('<{}> Input {}: Expected tensor with embedded quantization metadata. Either:\n'
+                                   '1. Make sure the previous operation is quantized\n'
+                                   '2. Provide explicit input quantization settings\n'
+                                   '3. Set inputs_quant_auto_fallback'.format(self.distiller_name, idx))
+            if self.preset_act_stats:
+                qmd = self.inputs_quant_metadata_fallback[idx]
             else:
-                # Input doesn't have embedded quantization data propagated from a previous layer
-                # Our options are:
-                #  If user set explicit settings for this input, use those
-                #  OR
-                #  If auto fallback is set, use the output quantization settings
-                if idx not in self.inputs_quant_settings_overrides and not self.inputs_quant_auto_fallback:
-                    raise RuntimeError('<{}> Input {}: Expected tensor with embedded quantization metadata. Either:\n'
-                                       '1. Make sure the previous operation is quantized\n'
-                                       '2. Provide explicit input quantization settings\n'
-                                       '3. Set inputs_quant_auto_fallback'.format(self.distiller_name, idx))
-                if self.preset_act_stats:
-                    input.quant_metadata = self.inputs_quant_metadata_fallback[idx]
+                if idx in self.inputs_quant_settings_overrides:
+                    q_settings = self.inputs_quant_settings_overrides[idx]
                 else:
-                    if idx in self.inputs_quant_settings_overrides:
-                        q_settings = self.inputs_quant_settings_overrides[idx]
-                    else:
-                        # If we're here then inputs_quant_auto_fallback is set
-                        # if self.num_forwards == 0:
-                        #     msglogger.info('<{}> Input {}: No embedded quantization metadata, '
-                        #                    'falling back to output settings'.format(self.distiller_name, idx))
-                        q_settings = deepcopy(self.output_quant_settings)
-                        q_settings.clip_half_range = False
-                    scale, zp = _get_quant_params_from_tensor(input, q_settings.num_bits, q_settings.quant_mode,
-                                                              q_settings.clip_mode, q_settings.per_channel,
-                                                              q_settings.clip_n_stds, q_settings.clip_half_range,
-                                                              self.scale_approx_mult_bits)
-                    signed = q_settings.quant_mode != LinearQuantMode.ASYMMETRIC_UNSIGNED
-                    min_q_val, max_q_val = get_quantized_range(q_settings.num_bits, signed)
-                    input.quant_metadata = TensorQuantMetadata(scale, zp, min_q_val, max_q_val)
+                    # If we're here then inputs_quant_auto_fallback is set
+                    q_settings = deepcopy(self.output_quant_settings)
+                    q_settings.clip_half_range = False
+                scale, zp = _get_quant_params_from_tensor(input, q_settings.num_bits, q_settings.quant_mode,
+                                                          q_settings.clip_mode, q_settings.per_channel,
+                                                          q_settings.clip_n_stds, q_settings.clip_half_range,
+                                                          self.scale_approx_mult_bits)
+                signed = q_settings.quant_mode != LinearQuantMode.ASYMMETRIC_UNSIGNED
+                min_q_val, max_q_val = get_quantized_range(q_settings.num_bits, signed)
+                qmd = TensorQuantMetadata(scale, zp, min_q_val, max_q_val)
+
+        # Make sure scale and zp are on correct device
+        qmd = TensorQuantMetadata(qmd.scale.to(input.device), qmd.zero_point.to(input.device),
+                                  qmd.min_q_val, qmd.max_q_val)
+        return qmd
 
     def quantized_forward(self, *inputs_q):
         """
@@ -554,24 +549,21 @@ class RangeLinearQuantWrapper(nn.Module):
 
         tmpstr = 'output_quant_settings={0}'.format(self.output_quant_settings)
         tmpstr += '\naccum_quant_settings={0}'.format(self.accum_quant_settings)
-        tmpstr += '\nrequires_quantized_inputs={0}'.format(self.requires_quantized_inputs)
-        if self.requires_quantized_inputs:
-            overrides = self.inputs_quant_settings_overrides
-            tmpstr += '\n  inputs_quant_auto_fallback={}'.format(self.inputs_quant_auto_fallback)
-            tmpstr += ', forced_quant_settings_for_inputs={}'.format(
-                'None' if not overrides else list(overrides.keys()))
-            for idx, qset in overrides.items():
-                tmpstr += '\n    input_{}_settings={}'.format(idx, qset)
+        overrides = self.inputs_quant_settings_overrides
+        tmpstr += '\n  inputs_quant_auto_fallback={}'.format(self.inputs_quant_auto_fallback)
+        tmpstr += ', forced_quant_settings_for_inputs={}'.format(
+            'None' if not overrides else list(overrides.keys()))
+        for idx, qset in overrides.items():
+            tmpstr += '\n    input_{}_settings={}'.format(idx, qset)
         tmpstr += '\nscale_approx_mult_bits={}'.format(self.scale_approx_mult_bits)
         tmpstr += '\npreset_activation_stats={0}'.format(self.preset_act_stats)
         if self.preset_act_stats:
             tmpstr += '\n  output_scale={0}, output_zero_point={1}'.format(_quant_param_to_str(
                 self.output_scale), _quant_param_to_str(self.output_zero_point))
-            if self.requires_quantized_inputs:
-                for idx in self.inputs_quant_settings_overrides:
-                    qmd = self.inputs_quant_metadata_fallback[idx]
-                    tmpstr += '\n  input_#{0}_scale={1}, input_#{0}_zero_point={2}'.format(
-                        idx, _quant_param_to_str(qmd.scale), _quant_param_to_str(qmd.zero_point))
+            for idx in self.inputs_quant_settings_overrides:
+                qmd = self.inputs_quant_metadata_fallback[idx]
+                tmpstr += '\n  input_#{0}_scale={1}, input_#{0}_zero_point={2}'.format(
+                    idx, _quant_param_to_str(qmd.scale), _quant_param_to_str(qmd.zero_point))
         return tmpstr
 
 
@@ -622,7 +614,6 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
                                                                 clip_acts, activation_stats, clip_n_stds, clip_half_range,
                                                                 scale_approx_mult_bits,
                                                                 input_overrides=input_overrides,
-                                                                requires_quantized_inputs=True,
                                                                 inputs_quant_auto_fallback=inputs_quant_auto_fallback)
 
         if not isinstance(wrapped_module, (nn.Conv2d, nn.Conv3d, nn.Linear)):
@@ -807,7 +798,6 @@ class RangeLinearQuantMatmulWrapper(RangeLinearQuantWrapper):
                                                             clip_acts, activation_stats, clip_n_stds, clip_half_range,
                                                             scale_approx_mult_bits,
                                                             input_overrides=input_overrides,
-                                                            requires_quantized_inputs=True,
                                                             inputs_quant_auto_fallback=inputs_quant_auto_fallback)
 
         if not isinstance(wrapped_module, (distiller.modules.Matmul, distiller.modules.BatchMatmul)):
@@ -859,7 +849,6 @@ class RangeLinearQuantConcatWrapper(RangeLinearQuantWrapper):
                                                             clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
                                                             scale_approx_mult_bits=scale_approx_mult_bits,
                                                             input_overrides=input_overrides,
-                                                            requires_quantized_inputs=True,
                                                             inputs_quant_auto_fallback=inputs_quant_auto_fallback)
 
     def quantized_forward(self, *inputs_q):
@@ -895,7 +884,6 @@ class RangeLinearQuantEltwiseAddWrapper(RangeLinearQuantWrapper):
                                                                 clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
                                                                 scale_approx_mult_bits=scale_approx_mult_bits,
                                                                 input_overrides=input_overrides,
-                                                                requires_quantized_inputs=True,
                                                                 inputs_quant_auto_fallback=inputs_quant_auto_fallback)
 
     def quantized_forward(self, *inputs_q):
@@ -932,7 +920,6 @@ class RangeLinearQuantEltwiseMultWrapper(RangeLinearQuantWrapper):
                                                                  clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
                                                                  scale_approx_mult_bits=scale_approx_mult_bits,
                                                                  input_overrides=input_overrides,
-                                                                 requires_quantized_inputs=True,
                                                                  inputs_quant_auto_fallback=inputs_quant_auto_fallback)
         self.accum_scale = 1
 
@@ -1039,18 +1026,29 @@ class RangeLinearEmbeddingWrapper(nn.Module):
 class RangeLinearFakeQuantWrapper(RangeLinearQuantWrapper):
     def __init__(self, wrapped_module, num_bits_acts, mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE,
                  activation_stats=None, clip_n_stds=None, clip_half_range=False, scale_approx_mult_bits=None,
-                 fpq_module=None):
+                 fpq_module=None, input_overrides=None, inputs_quant_auto_fallback=False):
         super(RangeLinearFakeQuantWrapper, self).__init__(wrapped_module, num_bits_acts, mode=mode,
                                                           clip_acts=clip_acts, activation_stats=activation_stats,
                                                           clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
                                                           scale_approx_mult_bits=scale_approx_mult_bits,
-                                                          requires_quantized_inputs=False)
+                                                          input_overrides=input_overrides,
+                                                          inputs_quant_auto_fallback=inputs_quant_auto_fallback)
         self.fpq_module = str(fpq_module) if fpq_module else None
         self.dtype = torch.float
         if self.fpq_module:
             self.dtype = {'16': torch.half, '32': torch.float, '64': torch.double}[self.fpq_module]
             self.wrapped_module.to(self.dtype)
 
+    def _prepare_input(self, idx, input):
+        previously_quantized = hasattr(input, 'quant_metadata')
+        input.quant_metadata = self._get_input_quant_metadata(idx, input)
+        if previously_quantized:
+            return input
+
+        # "Fresh" tensor, so need to quantize and the de-quantize (because this is the fake-quant wrapper)
+        input_q = linear_quantize_clamp_with_metadata(input, inplace=False)
+        return linear_dequantize_with_metadata(input_q, inplace=True)
+
     def quantized_forward(self, *inputs_q):
         inputs_q = distiller.convert_tensors_recursively_to(inputs_q, dtype=self.dtype)
         outputs = self.wrapped_module(*inputs_q)
@@ -1214,7 +1212,8 @@ class PostTrainLinearQuantizer(Quantizer):
                                                    activation_stats=activation_stats,
                                                    clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
                                                    scale_approx_mult_bits=scale_approx_mult_bits,
-                                                   fpq_module=fpq_module)
+                                                   fpq_module=fpq_module, input_overrides=input_overrides,
+                                                   inputs_quant_auto_fallback=inputs_quant_auto_fallback)
 
             return RangeLinearQuantParamLayerWrapper(module, qbits.acts, qbits.wts,
                                                      num_bits_accum=self.bits_accum, mode=mode, clip_acts=clip_acts,
@@ -1244,7 +1243,8 @@ class PostTrainLinearQuantizer(Quantizer):
                                                    activation_stats=activation_stats,
                                                    clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
                                                    scale_approx_mult_bits=scale_approx_mult_bits,
-                                                   fpq_module=fpq_module)
+                                                   fpq_module=fpq_module, input_overrides=input_overrides,
+                                                   inputs_quant_auto_fallback=inputs_quant_auto_fallback)
             try:
                 return wrapper_type(module, qbits.acts, mode=mode, clip_acts=clip_acts,
                                     activation_stats=activation_stats,
@@ -1267,8 +1267,9 @@ class PostTrainLinearQuantizer(Quantizer):
 
         def replace_fake_quant(module, name, qbits_map, fp16=fp16,
                                clip_acts=clip_acts, clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
-                               scale_approx_mult_bits=scale_approx_mult_bits, fpq_module=fpq_module, fake=True,
-                               make_identity=False):
+                               scale_approx_mult_bits=scale_approx_mult_bits, input_overrides=None,
+                               inputs_quant_auto_fallback=inputs_quant_auto_fallback,
+                               fpq_module=fpq_module, fake=True, make_identity=False):
             if isinstance(module, (nn.ReLU, nn.ReLU6)) and make_identity:
                 named_modules = OrderedDict(self.model.named_modules())
                 pred = self.adjacency_map[name].predecessors[0].name
@@ -1288,7 +1289,8 @@ class PostTrainLinearQuantizer(Quantizer):
                                                activation_stats=self.model_activation_stats.get(norm_name, None),
                                                clip_n_stds=clip_n_stds,  clip_half_range=clip_half_range,
                                                scale_approx_mult_bits=scale_approx_mult_bits,
-                                               fpq_module=fpq_module)
+                                               fpq_module=fpq_module, input_overrides=input_overrides,
+                                               inputs_quant_auto_fallback=inputs_quant_auto_fallback)
 
         self.clip_acts = clip_acts
         self.clip_n_stds = clip_n_stds
-- 
GitLab