diff --git a/distiller/data_loggers/collector.py b/distiller/data_loggers/collector.py
index 612f7d3f5646c31710a0a83ef7053d87904541b6..7fef7487a45363ad314567d95413a849fe700e41 100755
--- a/distiller/data_loggers/collector.py
+++ b/distiller/data_loggers/collector.py
@@ -82,14 +82,20 @@ class ActivationStatsCollector(object):
         self.model.apply(partial(self._collect_activations_stats, activation_stats=activation_stats))
         return activation_stats
 
-    def start(self):
+    def start(self, modules_list=None):
         """Start collecting activation stats.
 
         This will iteratively register the modules' forward-hooks, so that the collector
         will be called from the forward traversal and get exposed to activation data.
+        modules_list (iterable): names of modules for which to collect stats. If None or empty, stats are
+          collected for all modules.
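+          e.g. (hypothetical module names): collector.start(modules_list=['conv1', 'layer1.0.conv1'])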
         """
         assert len(self.fwd_hook_handles) == 0
-        self.model.apply(self.start_module)
+        if not modules_list:
+            self.model.apply(self.start_module)
+            return
+        modules_dict = dict(self.model.named_modules())
+        for module_name in modules_list:
+            modules_dict[module_name].apply(self.start_module)
 
     def start_module(self, module):
         """Iteratively register to the forward-pass callback of all eligible modules.
@@ -707,7 +713,8 @@ class ActivationHistogramsCollector(ActivationStatsCollector):
 
 
 def collect_quant_stats(model, test_fn, save_dir=None, classes=None, inplace_runtime_check=False,
-                        disable_inplace_attrs=False, inplace_attr_names=('inplace',)):
+                        disable_inplace_attrs=False, inplace_attr_names=('inplace',),
+                        modules_to_collect=None):
     """
     Helper function for collecting quantization calibration statistics for a model using QuantCalibrationStatsCollector
 
@@ -722,6 +729,8 @@ def collect_quant_stats(model, test_fn, save_dir=None, classes=None, inplace_run
         inplace_runtime_check (bool): See QuantCalibrationStatsCollector
         disable_inplace_attrs (bool): See QuantCalibrationStatsCollector
         inplace_attr_names (iterable): See QuantCalibrationStatsCollector
+        modules_to_collect (iterable): names of modules for which to enable stats collection.
+          If None, stats are collected for all layers.
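+          e.g. (hypothetical module names): modules_to_collect=['conv1', 'layer1.0.conv1']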
 
     Returns:
         Dictionary with quantization stats (see QuantCalibrationStatsCollector for a description of the dictionary
@@ -732,7 +741,7 @@ def collect_quant_stats(model, test_fn, save_dir=None, classes=None, inplace_run
                                                            inplace_runtime_check=inplace_runtime_check,
                                                            disable_inplace_attrs=disable_inplace_attrs,
                                                            inplace_attr_names=inplace_attr_names)
-    with collector_context(quant_stats_collector):
+    with collector_context(quant_stats_collector, modules_to_collect):
         msglogger.info('Pass 1: Collecting min, max, avg_min, avg_max, mean, std')
         test_fn(model=model)
         # Collect Laplace distribution stats:
@@ -799,10 +808,10 @@ def collect_histograms(model, test_fn, save_dir=None, activation_stats=None,
 
 
 @contextmanager
-def collector_context(collector):
+def collector_context(collector, modules_list=None):
     """A context manager for an activation collector"""
     if collector is not None:
-        collector.reset().start()
+        collector.reset().start(modules_list)
     yield collector
     if collector is not None:
         collector.stop()
diff --git a/distiller/model_transforms.py b/distiller/model_transforms.py
index dca01f443d490684b6e538a8f095f53873271fbd..5761431bb02a48b50402dc174193f3b655776b0c 100644
--- a/distiller/model_transforms.py
+++ b/distiller/model_transforms.py
@@ -24,7 +24,7 @@ import logging
 msglogger = logging.getLogger()
 
 
-__all__ = ["fuse_modules", "fold_batch_norms_inference"]
+__all__ = ["fuse_modules", "fold_batch_norms"]
 
 
 def fuse_modules(model, types_sequence, fuse_fn, dummy_input=None, adjacency_map=None):
@@ -98,7 +98,7 @@ def fuse_modules(model, types_sequence, fuse_fn, dummy_input=None, adjacency_map
     return model
 
 
-def fold_batch_norms_inference(model, dummy_input=None, adjacency_map=None):
+def fold_batch_norms(model, dummy_input=None, adjacency_map=None, inference=True):
     """Scans the model for convolution / linear modules followed by batch-normalization. For each such valid pair,
     folds the parameters of the batch normalization module into the parameters of the parameter module, and replaces
     the batch normalization module with an identity operation.
@@ -112,6 +112,8 @@ def fold_batch_norms_inference(model, dummy_input=None, adjacency_map=None):
         adjacency_map (OrderedDict): Pre-computed adjacency map, via SummaryGraph.adjacency_map(). Must be based
           on the passed model, otherwise results are unexpected. If None, then the adjacency map will be created
           internally using the passed dummy_input.
+        inference (bool): whether the model is in inference (evaluation) mode.
+            If True, every BatchNorm is hard-fused ("frozen") into the preceding parameterized layer.
+            If False, each pair is replaced with a simulated-folding module (see fold_bn below), which can still be trained.
     """
     def fold_bn(sequence):
         # Re-use this functionality from simulated BN folding implementation
@@ -121,8 +123,10 @@ def fold_batch_norms_inference(model, dummy_input=None, adjacency_map=None):
         except ValueError:
             msglogger.debug("Can't fold, {} does not track running stats".format(bn_module.distiller_name))
             return None
-        folded_module.freeze()
-        return folded_module.param_module
+        if inference:
+            folded_module.freeze()
+            return folded_module.param_module
+        return folded_module
 
     foldables = (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)
     batchnorms = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
diff --git a/distiller/models/__init__.py b/distiller/models/__init__.py
index bf628db183fdfbc8187dd95ba543e8c3f3cb0c17..bedc28a5f12c7f1252150636da8aaf1aadd8be2e 100755
--- a/distiller/models/__init__.py
+++ b/distiller/models/__init__.py
@@ -20,12 +20,14 @@ import copy
 from functools import partial
 import torch
 import torchvision.models as torch_models
+import torch.nn as nn
 from . import cifar10 as cifar10_models
 from . import mnist as mnist_models
 from . import imagenet as imagenet_extra_models
 import pretrainedmodels
 
 from distiller.utils import set_model_input_shape_attr
+from distiller.modules import Mean, EltwiseAdd
 
 import logging
 msglogger = logging.getLogger()
@@ -59,16 +61,41 @@ ALL_MODEL_NAMES = sorted(map(lambda s: s.lower(),
                             set(IMAGENET_MODEL_NAMES + CIFAR10_MODEL_NAMES + MNIST_MODEL_NAMES)))
 
 
-# A temporary monkey-patch to get past this Torchvision bug:
-# https://github.com/pytorch/pytorch/issues/20516
-def patch_torchvision_mobilenet_v2_bug(model):
-    def patched_forward(self, x):
+def patch_torchvision_mobilenet_v2(model):
+    """
+    Patches TorchVision's MobileNetV2:
+    * To allow quantization, this adds modules for tensor operations (mean, element-wise addition) to the
+      model instance and patches the forward functions accordingly
+    * Fixes a bug in the torchvision implementation that prevents export to ONNX (and creation of SummaryGraph)
+    """
+    if not isinstance(model, torch_models.MobileNetV2):
+        raise TypeError("Only MobileNetV2 is acceptable.")
+
+    def patched_forward_mobilenet_v2(self, x):
         x = self.features(x)
-        #x = x.mean([2, 3])
-        x = x.mean(3).mean(2)
+        # x = x.mean([2, 3]) # this was a bug: https://github.com/pytorch/pytorch/issues/20516
+        x = self.mean32(x)
         x = self.classifier(x)
         return x
-    model.__class__.forward = patched_forward
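+    # Replace the functional mean with distiller.modules.Mean instances so the quantizer can detect and wrap them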
+    model.mean32 = nn.Sequential(
+        Mean(3), Mean(2)
+    )
+    model.__class__.forward = patched_forward_mobilenet_v2
+
+    def is_inverted_residual(module):
+        return isinstance(module, nn.Module) and module.__class__.__name__ == 'InvertedResidual'
+
+    def patched_forward_invertedresidual(self, x):
+        if self.use_res_connect:
+            return self.residual_eltwiseadd(self.conv(x), x)
+        else:
+            return self.conv(x)
+
+    for n, m in model.named_modules():
+        if is_inverted_residual(m):
+            if m.use_res_connect:
+                m.residual_eltwiseadd = EltwiseAdd()
+            m.__class__.forward = patched_forward_invertedresidual
 
 
 _model_extensions = {}
@@ -137,7 +164,7 @@ def _create_imagenet_model(arch, pretrained):
         try:
             model = getattr(torch_models, arch)(pretrained=pretrained)
             if arch == "mobilenet_v2":
-                patch_torchvision_mobilenet_v2_bug(model)
+                patch_torchvision_mobilenet_v2(model)
         except NotImplementedError:
             # In torchvision 0.3, trying to download a model that has no
             # pretrained image available will raise NotImplementedError
diff --git a/distiller/models/cifar10/resnet_cifar.py b/distiller/models/cifar10/resnet_cifar.py
index e9ce4e55ab026ba545223d42bce061328f4105d3..ca31731feaa115a21420b80973e6b23d1fd1f09f 100755
--- a/distiller/models/cifar10/resnet_cifar.py
+++ b/distiller/models/cifar10/resnet_cifar.py
@@ -37,6 +37,7 @@ This ResNet also has layer gates, to be able to dynamically remove layers.
 import torch.nn as nn
 import math
 import torch.utils.model_zoo as model_zoo
+from distiller.modules import EltwiseAdd
 
 
 __all__ = ['resnet20_cifar', 'resnet32_cifar', 'resnet44_cifar', 'resnet56_cifar']
@@ -62,6 +63,7 @@ class BasicBlock(nn.Module):
         self.relu2 = nn.ReLU(inplace=False)
         self.downsample = downsample
         self.stride = stride
+        self.residual_eltwiseadd = EltwiseAdd()
 
     def forward(self, x):
         residual = out = x
@@ -78,7 +80,7 @@ class BasicBlock(nn.Module):
         if self.downsample is not None:
             residual = self.downsample(x)
 
-        out += residual
+        out = self.residual_eltwiseadd(residual, out)
         out = self.relu2(out)
 
         return out
diff --git a/distiller/modules/__init__.py b/distiller/modules/__init__.py
index 03c3f57f4898a98cbe3f8649edb86aa0b3bf855b..71c72eef6d8cd75b5c4dc83ca9d4fc5a881bb8ac 100644
--- a/distiller/modules/__init__.py
+++ b/distiller/modules/__init__.py
@@ -18,9 +18,9 @@ from .eltwise import *
 from .grouping import *
 from .matmul import *
 from .rnn import *
-from .aggregate import Norm
+from .aggregate import *
 
 __all__ = ['EltwiseAdd', 'EltwiseMult', 'EltwiseDiv', 'Matmul', 'BatchMatmul',
            'Concat', 'Chunk', 'Split', 'Stack',
            'DistillerLSTMCell', 'DistillerLSTM', 'convert_model_to_distiller_lstm',
-           'Norm']
+           'Norm', 'Mean']
diff --git a/distiller/modules/aggregate.py b/distiller/modules/aggregate.py
index a217627440cb1d929f44a77a37a31de5df55b29a..f6f6cd469d4c215afd120b378d8cb11048050f0f 100644
--- a/distiller/modules/aggregate.py
+++ b/distiller/modules/aggregate.py
@@ -14,3 +14,13 @@ class Norm(nn.Module):
 
     def forward(self, x: torch.Tensor):
         return torch.norm(x, p=self.p, dim=self.dim, keepdim=self.keepdim)
+
+
+class Mean(nn.Module):
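+    """An nn.Module wrapper for torch.mean; positional and keyword arguments are forwarded to torch.mean."""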
+    def __init__(self, *args, **kwargs):
+        super(Mean, self).__init__()
+        self.args = args
+        self.kwargs = kwargs
+
+    def forward(self, x: torch.Tensor):
+        return torch.mean(x, *self.args, **self.kwargs)
diff --git a/distiller/quantization/ptq_greedy_search.py b/distiller/quantization/ptq_greedy_search.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f361febff5580555f179ee951626ff86602dd8d
--- /dev/null
+++ b/distiller/quantization/ptq_greedy_search.py
@@ -0,0 +1,488 @@
+#
+# Copyright (c) 2018 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+r"""
+Here we implement the greedy search algorithm for automatic quantization.
+"""
+import torch
+import torch.nn as nn
+from distiller.quantization.range_linear import PostTrainLinearQuantizer, ClipMode, LinearQuantMode
+from distiller.summary_graph import SummaryGraph
+from distiller.model_transforms import fold_batch_norms
+import distiller.modules
+from distiller.data_loggers import collect_quant_stats
+from distiller.models import create_model
+from collections import OrderedDict, defaultdict
+import logging
+from copy import deepcopy
+import distiller.apputils.image_classifier as classifier
+import os
+import distiller.apputils as apputils
+import re
+import argparse
+
+__all__ = ['ptq_greedy_search']
+
+msglogger = None
+
+QUANTIZED_MODULES = (
+    nn.Linear,
+    nn.Conv2d,
+    nn.Conv3d,
+    distiller.modules.Concat,
+    distiller.modules.EltwiseAdd,
+    distiller.modules.EltwiseMult,
+    distiller.modules.Matmul,
+    distiller.modules.BatchMatmul
+)
+
+FP16_LAYERS = (
+    nn.Tanh,
+    nn.Sigmoid
+)
+
+PARAM_MODULES = (
+    nn.Linear,
+    nn.Conv2d,
+    nn.Conv3d
+)
+
+UNQUANTIZED_MODULES = (
+    nn.Softmax,
+)
+
+SKIP_MODULES = (
+    nn.Identity,
+    nn.BatchNorm1d,
+    nn.BatchNorm2d,
+    nn.BatchNorm3d,
+    nn.ReLU,
+    nn.ReLU6
+)
+
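+# Activation clipping modes evaluated for each layer during the greedy search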
+CLIP_MODES = ['NONE',
+              'AVG',
+              'GAUSS',
+              'LAPLACE'
+              ]
+
+
+def get_default_args():
+    parser = classifier.init_classifier_compression_arg_parser()
+    parser.add_argument('--qe-no-quant-layers', '--qenql', type=str, nargs='+', metavar='LAYER_NAME', default=[],
+                        help='List of layer names for which to skip quantization.')
+    parser.add_argument('--qe-calib-portion', type=float, default=1.0,
+                        help='The portion of the dataset to use for calibration stats collection.')
+    parser.add_argument('--qe-calib-batchsize', type=int, default=256,
+                        help='Batch size to use for calibration stats collection.')
+    parser.add_argument('--base-score', type=float, default=None,
+                        help='Baseline (FP32) score. If not provided, it is evaluated before the search starts.')
+    parser.add_argument('--quantize-inputs', type=str, nargs='+', metavar='LAYER_NAME#INPUT_IDX', default=[],
+                        help='Layer inputs to quantize, each specified as LAYER_NAME#INPUT_IDX')
+    parser.add_argument('--resume-search-from', type=str, help='Search checkpoint file to resume.',
+                        default=None)
+    args = parser.parse_args()
+    return args
+
+
+def override_odict(**kwargs):
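+    """Create an OrderedDict of override settings from keyword arguments."""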
+    return OrderedDict(kwargs)
+
+
+def get_inputs_to_quantize(sg, args, recurrent=False):
+    """
+    Finds the modules in the graph that take the user input directly
+    Args:
+        sg (SummaryGraph): the summary graph of the model
+        args (argparse.Namespace): command line arguments; the inputs to quantize are read from
+          args.quantize_inputs as 'LAYER_NAME#INPUT_IDX' strings
+        recurrent: see SummaryGraph.layers_topological_order
+    TODO - implement properly
+    """
+    # input_modules = set()
+    # layers = set(sg.layers_topological_order(recurrent))
+    # for op in sg.top_level_ops():
+    #     input_modules.update(set(sg.successors(op, 1)) & layers)
+    # return list(input_modules)
+    result = defaultdict(lambda: [])
+    for input_str in args.quantize_inputs:
+        module_name, input_idx_str = input_str.split('#')
+        input_idx = int(input_idx_str)
+        result[module_name].append(input_idx)
+    return result
+
+
+def input_override_generator(module, module_name, sg, overrides_dict, **kwargs):
+    """
+    Generator for overrides on inputs of the input layers.
+    Args:
+        module (nn.Module): the module
+        module_name (str): module name as it appears in the summary graph
+        sg (SummaryGraph): a summary graph of the model
+        overrides_dict (OrderedDict): the fixed overrides already applied
+        kwargs: additional arguments, if needed
+    """
+    bits_acts = kwargs.get('bits_activations', 8)
+    bits_wts = kwargs.get('bits_weights', 8)
+    input_nodes = sg.predecessors(module_name, 1)
+    input_idx = kwargs.get('input_idx', 0)
+    assert input_idx < len(input_nodes)
+    for clip_mode in CLIP_MODES:
+        input_idx_override = override_odict(bits_activations=bits_acts,
+                                            clip_acts=clip_mode)
+        input_overrides = OrderedDict([(input_idx, input_idx_override)])
+        current_module_override = override_odict(input_overrides=input_overrides)
+        # add basic quantization so the quantizer doesn't reject this override
+        current_module_override['bits_activations'] = bits_acts
+        if isinstance(module, PARAM_MODULES):
+            current_module_override['bits_weights'] = bits_wts
+        yield current_module_override
+
+
+def module_override_generator(module, module_name, sg, overrides_dict, **kwargs):
+    """
+    Standard generator of overrides for the greedy search algorithm.
+    Args:
+        module (nn.Module): the module
+        module_name (str): module name as it appears in the summary graph
+        sg (SummaryGraph): a summary graph of the model
+        overrides_dict (OrderedDict): the fixed overrides already applied
+        kwargs: additional arguments, if needed
+    """
+    bits_acts = kwargs.get('bits_activations', 8)
+    bits_wts = kwargs.get('bits_weights', 8)
+    if isinstance(module, nn.ReLU):
+        yield override_odict(make_identity=True,
+                             bits_weights=bits_wts,
+                             bits_activations=bits_acts)
+        return
+    adj_map = sg.adjacency_map()
+    modules_dict = dict(sg._src_model.named_modules())
+    successors_names = {op.name for op in adj_map[module_name].successors if op.name in modules_dict}
+    use_half_range = all([isinstance(modules_dict[succ], nn.ReLU) for succ in successors_names])
+    use_fake = False
+    fpq_module = None
+    if isinstance(module, FP16_LAYERS):
+        fpq_module = 16
+        use_fake = True
+    if isinstance(module, UNQUANTIZED_MODULES) or not isinstance(module, QUANTIZED_MODULES):
+        fpq_module = 32
+        use_fake = True
+    for clip_mode in CLIP_MODES:
+        if isinstance(module, PARAM_MODULES):
+            current_module_override = override_odict(clip_acts=clip_mode,
+                                                     bits_weights=bits_wts,
+                                                     bits_activations=bits_acts,
+                                                     bits_bias=32)
+        else:
+            current_module_override = override_odict(clip_acts=clip_mode,
+                                                     fpq_module=fpq_module,
+                                                     fake=use_fake,
+                                                     bits_weights=bits_wts,
+                                                     bits_activations=bits_acts)
+        current_module_override['clip_half_range'] = use_half_range and clip_mode in ['GAUSS', 'LAPLACE']
+
+        yield current_module_override
+
+
+def search_best_local_settings(module, module_name, model, dummy_input, sg, act_stats, eval_fn, best_overrides_dict,
+                               override_gen_fn, **kwargs):
+    msglogger.info('Searching optimal quantization in \'%s\'(%s):' % (module_name, module.__class__.__name__))
+    overrides_dict = deepcopy(best_overrides_dict)
+    best_performance, best_local_override = float("-inf"), OrderedDict()
+    normalized_module_name = module_name
+    if isinstance(model, nn.DataParallel):
+        normalized_module_name = re.sub(r'module\.', '', normalized_module_name)
+    for local_override in override_gen_fn(module, module_name, sg, best_overrides_dict, **kwargs):
+        if not overrides_dict.get(normalized_module_name, None):
+            overrides_dict[normalized_module_name] = OrderedDict()
+        overrides_dict[normalized_module_name].update(local_override)
+        temp_act_stats = deepcopy(act_stats)
+        quantizer = PostTrainLinearQuantizer(deepcopy(model),
+                                             bits_activations=None,
+                                             bits_parameters=None,
+                                             bits_accum=32,
+                                             mode=LinearQuantMode.ASYMMETRIC_SIGNED,
+                                             clip_acts=ClipMode.NONE,
+                                             overrides=deepcopy(overrides_dict),
+                                             model_activation_stats=deepcopy(temp_act_stats),
+                                             inputs_quant_auto_fallback=False,
+                                             per_channel_wts=kwargs.get('per_channel', False))
+        quantizer.prepare_model(dummy_input)
+
+        current_performance = eval_fn(quantizer.model)
+        if not isinstance(module, UNQUANTIZED_MODULES):
+            clip_mode = local_override.get('clip_acts', None)
+            msglogger.info('\t%s\t score = %.3f\tLayer overrides: %s' %
+                           (clip_mode or '', current_performance, local_override))
+        else:
+            msglogger.info('\t Module is not quantized to int8. Not clipping activations.')
+            msglogger.info('\t score = %.3f\tLayer overrides: %s' %
+                           (current_performance, local_override))
+        if current_performance > best_performance:
+            best_performance = current_performance
+            best_local_override = local_override
+
+    msglogger.info('\t Choosing overrides: %s' % best_local_override)
+    return best_local_override
+
+
+def ptq_greedy_search(model, dummy_input, eval_fn, calib_eval_fn=None,
+                      recurrent=False, act_stats=None,
+                      args=None,
+                      module_override_gen_fn=None, input_override_gen_fn=None,
+                      fold_sequences=True):
+    """
+    Perform greedy search on Post Train Quantization configuration for the model.
+    Args:
+        model (nn.Module): the model to quantize
+        dummy_input (torch.Tensor): a dummy input to be passed to the model
+        eval_fn (function): Test/Evaluation function for the model. It must have an argument named 'model' that
+          accepts the model. All other arguments should be set in advance (can be done using functools.partial), or
+          they will be left with their default values.
+        calib_eval_fn (function): An 'evaluation' function used to forward-pass data through the model
+          in order to collect quantization calibration statistics.
+          If None, `eval_fn` is used as a default.
+        recurrent (bool): a flag indicating whether the model has recurrent connections.
+        act_stats (OrderedDict): quantization calibration activation stats.
+          If None, they are collected at runtime.
+        args (argparse.Namespace or dict): command line arguments.
+        module_override_gen_fn: A function to generate module overrides.
+          assumes signature
+          `def module_override_gen_fn(module: nn.Module,
+                                      module_name: str,
+                                      sg: distiller.SummaryGraph,
+                                      overrides_dict: OrderedDict,
+                                      **kwargs)-> Generator[OrderedDict, None, None]`
+        input_override_gen_fn: Same as module_override_gen_fn, but generates overrides for the inputs of the
+          top-level layers.
+        fold_sequences (bool): whether to fold batch norms before quantizing.
+    Returns:
+        (quantized_model, best_overrides_dict)
+    Note:
+        It is assumed that `eval_fn` returns a satisfying metric of performance (e.g. accuracy)
+        and the greedy search aims to maximize this metric.
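+    Example (sketch; `eval_fn` returns e.g. top-1 accuracy):
+        model, best_overrides = ptq_greedy_search(model, dummy_input, eval_fn)
+        distiller.yaml_ordered_save('best_overrides.yaml', best_overrides)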
+    """
+    if args is None:
+        args = get_default_args()
+    elif isinstance(args, dict):
+        updated_args = get_default_args()
+        updated_args.__dict__.update(args)
+        args = updated_args
+
+    if fold_sequences:
+        model = fold_batch_norms(model, dummy_input)
+    best_overrides_dict = OrderedDict()
+    if args.resume_search_from:
+        with open(args.resume_search_from, 'r') as f:
+            best_overrides_dict = distiller.yaml_ordered_load(f)
+        msglogger.info('Loaded search checkpoint from %s' % args.resume_search_from)
+    overrides_dict = OrderedDict()
+    sg = SummaryGraph(model, dummy_input)
+    modules_to_quantize = sg.layers_topological_order(recurrent)
+    adjacency_map = sg.adjacency_map()
+    modules_dict = OrderedDict(model.named_modules())  # type: OrderedDict[str, nn.Module]
+    modules_to_quantize = [m for m in modules_to_quantize
+                           if m not in args.qe_no_quant_layers]
+
+    module_override_gen_fn = module_override_gen_fn or module_override_generator
+    input_override_gen_fn = input_override_gen_fn or input_override_generator
+
+    calib_eval_fn = calib_eval_fn or eval_fn
+    if not act_stats:
+        msglogger.info('Collecting stats for model...')
+        model_temp = distiller.utils.make_non_parallel_copy(model)
+        act_stats = collect_quant_stats(model_temp, calib_eval_fn)
+        del model_temp
+        if args:
+            act_stats_path = '%s_act_stats.yaml' % args.arch
+            msglogger.info('Done. Saving act stats into %s' % act_stats_path)
+            distiller.yaml_ordered_save(act_stats_path, act_stats)
+    msglogger.info('Evaluating baseline score for model...')
+    base_score = args.base_score or eval_fn(model)
+    msglogger.info("Base score: %.3f" % base_score)
+
+    def recalibrate_stats(module_name, act_stats):
+        """
+        Re-collects quant-calibration stats for successor modules of the current module.
+        """
+        msglogger.info('Recalibrating stats...')
+        modules_to_recalibrate = {op.name for op in adjacency_map[module_name].successors} & set(act_stats)
+        if not modules_to_recalibrate:
+            # either there aren't any successors or
+            # the successors aren't in the stats file - skip
+            return act_stats
+        q = PostTrainLinearQuantizer(distiller.utils.make_non_parallel_copy(model),
+                                     bits_activations=None,
+                                     bits_parameters=None,
+                                     bits_accum=32,
+                                     mode=LinearQuantMode.ASYMMETRIC_SIGNED,
+                                     clip_acts=ClipMode.NONE,
+                                     overrides=deepcopy(best_overrides_dict),
+                                     model_activation_stats=deepcopy(act_stats),
+                                     inputs_quant_auto_fallback=False,
+                                     per_channel_wts=args.qe_per_channel)
+        q.prepare_model(dummy_input)
+        # recalibrate on the current best quantized version of the model.
+        recalib_act_stats = collect_quant_stats(q.model, calib_eval_fn, modules_to_collect=modules_to_recalibrate)
+        msglogger.info('Done.')
+        act_stats.update(recalib_act_stats)
+        return act_stats
+
+    loaded_from_checkpoint = []
+    # Quantize inputs:
+    input_modules = get_inputs_to_quantize(sg, args, recurrent)  # top level modules
+    for module_name, input_idxs in input_modules.items():
+        denormalized_module_name = distiller.denormalize_module_name(model, module_name)
+        module = modules_dict[denormalized_module_name]
+        if isinstance(module, SKIP_MODULES):
+            msglogger.info('Skipping module \'%s\' of type %s.' % (module_name, type(module)))
+            continue
+        msglogger.info('Quantizing top level inputs for %s' % module_name)
+
+        normalized_module_name = module_name
+        if isinstance(model, nn.DataParallel):
+            normalized_module_name = re.sub(r'module\.', '', normalized_module_name)
+        if normalized_module_name in best_overrides_dict and \
+                best_overrides_dict[normalized_module_name].get('input_overrides', None):
+            # This means the loaded dict already has the module
+            msglogger.info("  Quantizing '%s' based on loaded checkpoint: %s" %
+                           (module_name, best_overrides_dict[normalized_module_name]))
+            if best_overrides_dict[normalized_module_name].get('bits_activations'):
+                loaded_from_checkpoint.append(normalized_module_name)
+            continue
+        if not best_overrides_dict.get(normalized_module_name, None):
+            best_overrides_dict[normalized_module_name] = OrderedDict()
+        for input_idx in input_idxs:
+            best_module_override = search_best_local_settings(module, module_name, model, dummy_input, sg, act_stats,
+                                                              eval_fn, best_overrides_dict,
+                                                              input_override_gen_fn, input_idx=input_idx,
+                                                              bits_activations=args.qe_bits_acts,
+                                                              bits_weights=args.qe_bits_wts,
+                                                              per_channel=args.qe_per_channel)
+            best_overrides_dict[normalized_module_name].update(best_module_override)
+        # Leave only the input_overrides settings:
+        current_input_overrides = best_overrides_dict[normalized_module_name]['input_overrides']
+        best_overrides_dict[normalized_module_name] = override_odict(input_overrides=current_input_overrides)
+
+    # Quantize layers as a whole:
+    for module_name in modules_to_quantize:
+        module = modules_dict[module_name]
+        if isinstance(module, SKIP_MODULES):
+            msglogger.info('Skipping module \'%s\' of type %s.' % (module_name, module.__class__.__name__))
+            continue
+
+        normalized_module_name = module_name
+        if isinstance(model, nn.DataParallel):
+            normalized_module_name = re.sub(r'module\.', '', normalized_module_name)
+
+        if normalized_module_name in best_overrides_dict and \
+                best_overrides_dict[normalized_module_name].get('bits_activations', None)\
+                and normalized_module_name not in loaded_from_checkpoint:
+            # This means the loaded dict already has the module
+            msglogger.info("  Quantizing '%s'(%s) based on loaded checkpoint: %s" %
+                           (module_name, module.__class__.__name__, best_overrides_dict[normalized_module_name]))
+            loaded_from_checkpoint.append(normalized_module_name)
+            continue
+        if not best_overrides_dict.get(normalized_module_name, None):
+            best_overrides_dict[normalized_module_name] = OrderedDict()
+        # Hard coded workaround for avgpool->reshape->fc
+        if normalized_module_name == 'fc':
+            input_override = override_odict(bits_activations=8,
+                                            clip_acts='NONE')
+            best_overrides_dict['fc'].update(OrderedDict([
+                ('input_overrides', OrderedDict([
+                    (0, input_override)
+                ]))
+            ]))
+        best_module_override = search_best_local_settings(module, module_name, model, dummy_input, sg, act_stats,
+                                                          eval_fn, best_overrides_dict,
+                                                          module_override_gen_fn,
+                                                          bits_activations=args.qe_bits_acts,
+                                                          bits_weights=args.qe_bits_wts,
+                                                          per_channel=args.qe_per_channel)
+        best_overrides_dict[normalized_module_name].update(best_module_override)
+        distiller.yaml_ordered_save('%s.ptq_greedy_search.yaml' % args.arch, best_overrides_dict)
+        # # end of search - we update the calibration of the next layers:
+        # recalibrate_stats(module_name, act_stats)
+
+    quantizer = PostTrainLinearQuantizer(model,
+                                         bits_activations=None,
+                                         bits_parameters=None,
+                                         bits_accum=32,
+                                         mode=LinearQuantMode.ASYMMETRIC_SIGNED,
+                                         clip_acts=ClipMode.NONE,
+                                         overrides=deepcopy(best_overrides_dict),
+                                         model_activation_stats=act_stats,
+                                         inputs_quant_auto_fallback=False,
+                                         per_channel_wts=args.qe_per_channel)
+    quantizer.prepare_model(dummy_input)
+    msglogger.info('best_overrides_dict: %s' % best_overrides_dict)
+    msglogger.info('Best score: %f' % eval_fn(quantizer.model))
+    return model, best_overrides_dict
+
+
+if __name__ == "__main__":
+    args = get_default_args()
+    args.epochs = float('inf')  # hack for args parsing so there's no error in epochs
+    cc = classifier.ClassifierCompressor(args, script_dir=os.path.dirname(__file__))
+    eval_data_loader = classifier.load_data(args, load_train=False, load_val=False)
+
+    # quant calibration dataloader:
+    args.effective_test_size = args.qe_calib_portion
+    args.batch_size = args.qe_calib_batchsize
+    calib_data_loader = classifier.load_data(args, load_train=False, load_val=False)
+    # logging
+    logging.getLogger().setLevel(logging.WARNING)
+    msglogger = logging.getLogger(__name__)
+    msglogger.setLevel(logging.INFO)
+
+    def test_fn(model):
+        top1, top5, losses = classifier.test(eval_data_loader, model, cc.criterion, [cc.tflogger, cc.pylogger], None,
+                                             args)
+        return top1
+
+    def calib_eval_fn(model):
+        classifier.test(calib_data_loader, model, cc.criterion, [], None,
+                        args)
+
+    model = create_model(args.pretrained, args.dataset, args.arch,
+                         parallel=not args.load_serialized, device_ids=args.gpus)
+    args.device = next(model.parameters()).device
+    if args.resumed_checkpoint_path:
+        args.load_model_path = args.resumed_checkpoint_path
+    if args.load_model_path:
+        msglogger.info("Loading checkpoint from %s" % args.load_model_path)
+        model = apputils.load_lean_checkpoint(model, args.load_model_path,
+                                              model_device=args.device)
+    dummy_input = torch.rand(*model.input_shape, device=args.device)
+    if args.qe_stats_file:
+        msglogger.info("Loading stats from %s" % args.qe_stats_file)
+        with open(args.qe_stats_file, 'r') as f:
+            act_stats = distiller.yaml_ordered_load(f)
+    else:
+        act_stats = None
+    model, overrides = ptq_greedy_search(model, dummy_input, test_fn,
+                                         calib_eval_fn=calib_eval_fn, args=args,
+                                         act_stats=act_stats)
+    # Prepare a compression scheduler yaml config file:
+    quantizer_dict = OrderedDict([
+        ('class', 'PostTrainLinearQuantizer')
+    ])
+    quantizer_dict.update(deepcopy(model.quantizer_metadata['params']))
+    quantizer_dict['overrides'] = overrides
+    quantizer_dict['model_activation_stats'] = os.path.abspath('%s_act_stats.yaml' % args.arch)
+    sched_dict = OrderedDict([
+        ('quantizers', OrderedDict([
+            ('post_train_quantizer', quantizer_dict)
+        ]))
+    ])
+    distiller.yaml_ordered_save('%s.ptqgs_quantizer_sched_dict.yaml' % args.arch, sched_dict)
diff --git a/distiller/quantization/quantizer.py b/distiller/quantization/quantizer.py
index 2b3ac788e21b034fd8d87d38ea464825926b6309..84d0a246c44bf0b784a9bb58bd2228924fe78681 100644
--- a/distiller/quantization/quantizer.py
+++ b/distiller/quantization/quantizer.py
@@ -102,6 +102,10 @@ class Quantizer(object):
                 3.1 Gradients calculated with respect to q_weights
                 3.2 We also back-prop through the 'quantize' operation from step 1
             4. Update fp_weights with gradients calculated in step 3.2
+        Note:
+            When the model is wrapped in `nn.DataParallel`, the keys of the `overrides` dictionary should be
+            the module names *without* the `module.` prefix, e.g.:
+            module.conv1 -> OrderedDict([('conv1', OrderedDict(...))])
     """
     def __init__(self, model, optimizer=None,
                  bits_activations=None, bits_weights=None, bits_bias=None,
@@ -147,7 +151,7 @@ class Quantizer(object):
             regex_overrides_str = '|'.join(['(^{0}$)'.format(pattern) for pattern in patterns])
             regex_overrides = re.compile(regex_overrides_str)
 
-        self.module_qbits_map = {}
+        self.module_qbits_map = {}  # type: OrderedDict[str, QBits]
         self.module_overrides_map = {}
         for module_full_name, module in model.named_modules():
             # Need to account for scenario where model is parallelized with DataParallel, which wraps the original
@@ -171,7 +175,8 @@ class Quantizer(object):
         # Mapping from module type to function generating a replacement module suited for quantization
         # To be populated by child classes
         # Unspecified layer types return None by default.
-        self.replacement_factory = defaultdict(lambda: None)
+        self.replacement_factory = OrderedDict([(nn.Identity, None)])
+        self.default_repalcement_fn = None
         # Pointer to parameters quantization function, triggered during training process
         # To be populated by child classes
         self.param_quantization_fn = None
@@ -252,6 +257,8 @@ class Quantizer(object):
         # Re-transfer model to the device it was on, in case the quantizer created new parameters/buffers
         self.model.to(model_device)
 
+        distiller.assign_layer_fq_names(self.model)
+
         msglogger.info('Quantized model:\n\n{0}\n'.format(self.model))
 
     def _pre_prepare_model(self, dummy_input):
@@ -281,15 +288,15 @@ class Quantizer(object):
                     replace_msg(full_name)
                 continue
             current_qbits = self.module_qbits_map[full_name]
-            if current_qbits.acts is None and current_qbits.wts is None:
-                if self.module_overrides_map[full_name]:
-                    raise ValueError("Adding overrides while not quantizing is not allowed.")
+            # TODO - Review necessity of the block below
+            if current_qbits.acts is None and current_qbits.wts is None and not self.module_overrides_map[full_name]:
                 # We indicate this module wasn't replaced by a wrapper
                 replace_msg(full_name)
                 self.modules_processed[module] = full_name, None
             else:
                 # We use a type hint comment to let IDEs know replace_fn is a function
-                replace_fn = self.replacement_factory[type(module)]  # type: Optional[Callable]
+                replace_fn = self.replacement_factory.get(type(module),
+                                                          self.default_repalcement_fn)  # type: Optional[Callable]
                 # If the replacement function wasn't specified - continue without replacing this module.
                 if replace_fn is not None:
                     valid_kwargs, invalid_kwargs = distiller.filter_kwargs(self.module_overrides_map[full_name],
@@ -299,16 +306,21 @@ class Quantizer(object):
                                             as override arguments for %s. Allowed kwargs: %s"""
                                         % (type(self), list(invalid_kwargs), type(module), list(valid_kwargs)))
                     new_module = replace_fn(module, full_name, self.module_qbits_map, **valid_kwargs)
-                    replace_msg(full_name, (module, new_module))
-                    # Add to history of prepared submodules
-                    self.modules_processed[module] = full_name, new_module
-                    setattr(container, name, new_module)
-
-                    # If a "leaf" module was replaced by a container, add the new layers to the QBits mapping
-                    if not distiller.has_children(module) and distiller.has_children(new_module):
-                        for sub_module_name, sub_module in new_module.named_modules():
-                            self._add_qbits_entry(full_name + '.' + sub_module_name, type(sub_module), current_qbits)
-                        self.module_qbits_map[full_name] = QBits(acts=current_qbits.acts, wts=None, bias=None)
+                    if new_module != module:
+                        replace_msg(full_name, (module, new_module))
+                        # Add to history of prepared submodules
+                        self.modules_processed[module] = full_name, new_module
+                        setattr(container, name, new_module)
+
+                        # If a "leaf" module was replaced by a container, add the new layers to the QBits mapping
+                        if not distiller.has_children(module) and distiller.has_children(new_module):
+                            for sub_module_name, sub_module in new_module.named_modules():
+                                self._add_qbits_entry(full_name + '.' + sub_module_name, type(sub_module),
+                                                      current_qbits)
+                            self.module_qbits_map[full_name] = QBits(acts=current_qbits.acts, wts=None, bias=None)
+                    else:
+                        replace_msg(full_name)
+                        self.modules_processed[module] = full_name, None
 
             if distiller.has_children(module):
                 # For container we call recursively
diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index 58b4a187d7afa1ab305599d14d94769d8492ac50..703e8721848e36d9d4803021849ffe42082eef15 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -15,9 +15,10 @@
 #
 
 import torch.nn as nn
+import torch.nn.functional as f
 import argparse
 from enum import Enum
-from collections import OrderedDict
+from collections import OrderedDict, namedtuple
 from functools import reduce, partial, update_wrapper
 import logging
 import os
@@ -26,15 +27,28 @@ import warnings
 
 import distiller
 import distiller.utils
-from .quantizer import Quantizer
+from .quantizer import Quantizer, QBits
 from .q_utils import *
+from .sim_bn_fold import SimulatedFoldedBatchNorm
 import distiller.modules
 import distiller.model_transforms as mt
 
 msglogger = logging.getLogger()
 
 
+def _quant_param_to_str(val):
+    if isinstance(val, torch.Tensor):
+        if val.numel() > 1:
+            return 'PerCh'
+        else:
+            return '{:.6f}'.format(val.item())
+    return '{:.6f}'.format(val)
+
+
 def _enum_to_str(enum_val):
+    # TODO: This can probably be removed
+    if isinstance(enum_val, str): # temporary fix
+        return enum_val
     return str(enum_val).split('.')[1]
 
 
@@ -51,7 +65,6 @@ class ClipMode(Enum):
     AVG = 1
     # Clip value calculated as mean of tensor + N standard deviations. N should be specified separately
     N_STD = 2
-
     # ACIQ Clipping Modes -
     GAUSS = 3
     LAPLACE = 4
@@ -95,6 +108,7 @@ def _get_saturation_fn(quant_mode, clip_mode, num_stds, num_bits=None):
     return fns[clip_mode]
 
 
+# TODO: Move to q_utils, add tests
 def _get_quant_params_from_tensor(tensor, num_bits, mode, clip=ClipMode.NONE, per_channel=False, num_stds=None,
                                   half_range=False, scale_approx_mult_bits=None):
     if per_channel and tensor.dim() not in [2, 4]:
@@ -105,9 +119,6 @@ def _get_quant_params_from_tensor(tensor, num_bits, mode, clip=ClipMode.NONE, pe
             raise ValueError('N_STD clipping not supported with per-channel quantization')
         if num_stds is None:
             raise ValueError('Clip mode set top N_STD but \'num_stds\' parameter not provided')
-    if half_range and clip not in [ClipMode.GAUSS, ClipMode.LAPLACE]:
-        warnings.warn("Using clip_half_range without ACIQ clip modes (GAUSS or LAPACE) will have no"
-                      " effect.")
 
     dim = 0 if clip == ClipMode.AVG or per_channel else None
     sat_fn = _get_saturation_fn(mode, clip, num_stds, num_bits)
@@ -132,6 +143,7 @@ def _get_quant_params_from_tensor(tensor, num_bits, mode, clip=ClipMode.NONE, pe
     return scale, zp
 
 
+# TODO: Move to q_utils, add tests
 def _get_quant_params_from_stats_dict(stats, num_bits, mode, clip=ClipMode.NONE, num_stds=None, half_range=False,
                                       scale_approx_mult_bits=None):
     if clip == ClipMode.N_STD:
@@ -139,9 +151,6 @@ def _get_quant_params_from_stats_dict(stats, num_bits, mode, clip=ClipMode.NONE,
             raise ValueError('Clip mode set to N_STD but \'num_stds\' parameter not provided')
         if num_stds <= 0:
             raise ValueError('n_stds must be > 0, got {}'.format(num_stds))
-    if half_range and clip not in [ClipMode.GAUSS, ClipMode.LAPLACE]:
-        warnings.warn("Using clip_half_range without ACIQ clip modes (GAUSS or LAPACE) will have no"
-                      " effect.")
 
     prefix = 'avg_' if clip == ClipMode.AVG else ''
     sat_min = torch.tensor(float(stats[prefix + 'min']))
@@ -174,6 +183,39 @@ def _get_quant_params_from_stats_dict(stats, num_bits, mode, clip=ClipMode.NONE,
 # Post Training
 ###############################################################################
 
+class TensorQuantMetadata(namedtuple('TensorQuantMetadata', ['scale', 'zero_point', 'min_q_val', 'max_q_val'])):
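+    """Quantization parameters of a tensor: scale, zero-point and the clamped (min, max) quantized value range."""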
+    __slots__ = ()
+
+    def __str__(self):
+        return '(scale={} ; zero_point={})'.format(_quant_param_to_str(self.scale),
+                                                   _quant_param_to_str(self.zero_point))
+
+
+class QuantSettings(object):
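+    """Groups the quantization settings for a single tensor (bit-width, mode, clipping and per-channel flag)."""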
+    def __init__(self, num_bits, quant_mode, clip_mode, clip_n_stds, clip_half_range, per_channel):
+        self.num_bits = num_bits
+        self.quant_mode = quant_mode
+        self.clip_mode = clip_mode
+        self.clip_n_stds = clip_n_stds
+        self.clip_half_range = clip_half_range
+        self.per_channel = per_channel
+
+    def __str__(self):
+        return '(num_bits={} ; quant_mode={} ; clip_mode={} ; clip_n_stds={} ; clip_half_range={}' \
+               ' ; per_channel={})'.format(self.num_bits, _enum_to_str(self.quant_mode),
+                                           _enum_to_str(self.clip_mode), self.clip_n_stds, self.clip_half_range,
+                                           self.per_channel)
+
+
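+# Helpers that quantize / de-quantize a tensor using the TensorQuantMetadata attached to it as 'quant_metadata'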
+def linear_quantize_clamp_with_metadata(t, inplace=False):
+    return linear_quantize_clamp(t, *t.quant_metadata, inplace)
+
+
+def linear_dequantize_with_metadata(t, inplace=False):
+    qmd = t.quant_metadata
+    return linear_dequantize(t, qmd.scale, qmd.zero_point, inplace)
+
 
 def add_post_train_quant_args(argparser):
     str_to_quant_mode_map = {'sym': LinearQuantMode.SYMMETRIC,
@@ -254,17 +296,42 @@ class RangeLinearQuantWrapper(nn.Module):
 
     def __init__(self, wrapped_module, num_bits_acts, num_bits_accum=32, mode=LinearQuantMode.SYMMETRIC,
                  clip_acts=ClipMode.NONE, activation_stats=None, clip_n_stds=None, clip_half_range=False,
-                 scale_approx_mult_bits=None):
+                 scale_approx_mult_bits=None,
+                 input_overrides=None, requires_quantized_inputs=True, inputs_quant_auto_fallback=False):
         super(RangeLinearQuantWrapper, self).__init__()
 
+        input_overrides = input_overrides or OrderedDict()
+
         self.wrapped_module = wrapped_module
-        self.num_bits_acts = num_bits_acts
-        self.num_bits_accum = num_bits_accum
-        self.mode = mode
-        self.clip_acts = clip_acts
-        self.clip_n_stds = clip_n_stds
         self.clip_half_range = clip_half_range
         self.scale_approx_mult_bits = scale_approx_mult_bits
+        self.requires_quantized_inputs = requires_quantized_inputs
+        self.inputs_quant_auto_fallback = inputs_quant_auto_fallback
+
+        self.output_quant_settings = QuantSettings(num_bits_acts, mode, clip_acts, clip_n_stds, clip_half_range, False)
+        self.accum_quant_settings = QuantSettings(num_bits_accum, LinearQuantMode.SYMMETRIC,
+                                                  ClipMode.NONE, None, False, False)
+        if self.requires_quantized_inputs:
+            self.inputs_quant_settings_overrides = OrderedDict()
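+            # Each entry maps an input index to its settings, e.g. 0: {'bits_activations': 8, 'clip_acts': 'AVG'}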
+            for k, v in input_overrides.items():
+                idx = int(k)
+                if v.pop('from_output', None):
+                    quant_settings = deepcopy(self.output_quant_settings)
+                    quant_settings.clip_half_range = False
+                else:
+                    quant_settings = QuantSettings(v.pop('bits_activations', self.output_quant_settings.num_bits),
+                                                   verify_quant_mode(
+                                                       v.pop('mode', self.output_quant_settings.quant_mode)),
+                                                   verify_clip_mode(
+                                                       v.pop('clip_acts', self.output_quant_settings.clip_mode)),
+                                                   v.pop('clip_n_stds', self.output_quant_settings.clip_n_stds),
+                                                   False, False)
+                    if v:
+                        # Poor man's input checking on input overrides dict
+                        raise ValueError('Input overrides dict contains unsupported keys:', list(v.keys()))
+                self.inputs_quant_settings_overrides[idx] = quant_settings
+        else:
+            self.inputs_quant_settings_overrides = None
 
         # Controls whether output is de-quantized at end of forward op. Meant as a debug / test flag only
         # (note that if False, the quantized output will be returned, but without any quantization parameters,
@@ -278,58 +345,59 @@ class RangeLinearQuantWrapper(nn.Module):
 
         if activation_stats:
             self.preset_act_stats = True
-            self.num_inputs = 0
-            for idx, stats in activation_stats['inputs'].items():
-                self.num_inputs += 1
-                scale, zp = _get_quant_params_from_stats_dict(stats, num_bits_acts,
-                                                              mode, clip_acts, clip_n_stds, clip_half_range,
-                                                              scale_approx_mult_bits)
-                prefix = 'in_{0}_'.format(idx)
-                self.register_buffer(prefix + 'scale', scale)
-                self.register_buffer(prefix + 'zero_point', zp)
-
-            scale, zp = _get_quant_params_from_stats_dict(activation_stats['output'], num_bits_acts, mode,
-                                                          clip_acts, clip_n_stds, clip_half_range,
-                                                          scale_approx_mult_bits)
+
+            if self.requires_quantized_inputs:
+                self.inputs_quant_metadata_fallback = OrderedDict()
+                for idx, stats in activation_stats['inputs'].items():
+                    settings = self.inputs_quant_settings_overrides.get(idx, self.output_quant_settings)
+                    scale, zp = _get_quant_params_from_stats_dict(
+                        stats, settings.num_bits, settings.quant_mode, settings.clip_mode,
+                        settings.clip_n_stds, settings.clip_half_range, self.scale_approx_mult_bits
+                    )
+                    min_q_val, max_q_val = get_quantized_range(
+                        settings.num_bits, settings.quant_mode != LinearQuantMode.ASYMMETRIC_UNSIGNED)
+                    qmd = TensorQuantMetadata(scale, zp, min_q_val, max_q_val)
+                    self.inputs_quant_metadata_fallback[idx] = qmd
+            else:
+                self.inputs_quant_metadata_fallback = None
+
+            scale, zp = _get_quant_params_from_stats_dict(activation_stats['output'], num_bits_acts, mode, clip_acts,
+                                                          clip_n_stds, clip_half_range, scale_approx_mult_bits)
             self.register_buffer('output_scale', scale)
             self.register_buffer('output_zero_point', zp)
         else:
             self.preset_act_stats = False
 
-    def inputs_scales(self):
-        for scale in self._inputs_qparam('scale'):
-            yield scale
-
-    def inputs_zero_points(self):
-        for zp in self._inputs_qparam('zero_point'):
-            yield zp
-
-    def _inputs_qparam(self, type_str):
-        if type_str not in ['scale', 'zero_point']:
-            raise ValueError('Unknown quantization parameter type')
-        if not self.preset_act_stats:
-            raise RuntimeError('Input quantization parameter iterators only available when activation stats were given')
-        for idx in range(self.num_inputs):
-            name = 'in_{0}_{1}'.format(idx, type_str)
-            yield getattr(self, name)
+        self.register_buffer('num_forwards', torch.zeros(1, dtype=torch.long))
 
     def forward(self, *inputs):
         if self.training:
             raise RuntimeError(self.__class__.__name__ + " can only be used in eval mode")
+
         device = inputs[0].device
         for buffer_name, buffer in self._buffers.items():
             setattr(self, buffer_name, buffer.to(device))
 
-        in_scales, in_zero_points = self.get_inputs_quantization_params(*inputs)
-
-        # Quantize inputs
-        inputs_q = [linear_quantize_clamp(input.data, scale, zp,
-                                          self.acts_min_q_val, self.acts_max_q_val, inplace=False)
-                    for input, scale, zp in zip(inputs, in_scales, in_zero_points)]
+        if self.requires_quantized_inputs:
+            self._prepare_inputs_for_quantization(inputs)
+
+            inputs_q = []
+            for input in inputs:
+                qmd = input.quant_metadata
+                input.quant_metadata = TensorQuantMetadata(qmd.scale.to(device), qmd.zero_point.to(device),
+                                                           qmd.min_q_val, qmd.max_q_val)
+                input_q = linear_quantize_clamp_with_metadata(input)
+                input_q.quant_metadata = input.quant_metadata
+                inputs_q.append(input_q)
+        else:
+            inputs_q = inputs
 
         # Forward through wrapped module
         accum = self.quantized_forward(*inputs_q)
 
+        if self.clip_half_range:
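+            # clip_half_range indicates the layer is followed by a ReLU, so apply it here, on the accumulator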
+            accum = f.relu(accum)
+
         # Re-quantize accumulator to quantized output range
         out_scale, out_zero_point = self.get_output_quantization_params(accum)
         requant_scale, requant_zero_point = self.get_accum_to_output_re_quantization_params(out_scale, out_zero_point)
@@ -342,18 +410,48 @@ class RangeLinearQuantWrapper(nn.Module):
         # De-quantize back to FP32
         out_f = linear_dequantize(out_q, out_scale, out_zero_point, inplace=True)
 
-        return out_f
+        out_f.quant_metadata = TensorQuantMetadata(out_scale, out_zero_point, self.acts_min_q_val, self.acts_max_q_val)
 
-    def get_inputs_quantization_params(self, *inputs):
-        """
-        Calculate input quantization parameters (scale and zero-point)
+        self.num_forwards += 1
 
-        Should be overridden by all subclasses
+        return out_f
 
-        :param inputs: Current input tensors passed to forward method
-        :return: Tuple of 2 lists - list of scales per input and list of zero-point per input
-        """
-        raise NotImplementedError
+    def _prepare_inputs_for_quantization(self, inputs):
+        for idx, input in enumerate(inputs):
+            if hasattr(input, 'quant_metadata'):
+                if idx in self.inputs_quant_settings_overrides:
+                    raise RuntimeError('<{}> Input {}: CONFLICT - Tensor has embedded quantization metadata AND user '
+                                       'defined input quantization settings'.format(self.distiller_name, idx))
+            else:
+                # Input doesn't have embedded quantization data propagated from a previous layer
+                # Our options are:
+                #  If user set explicit settings for this input, use those
+                #  OR
+                #  If auto fallback is set, use the output quantization settings
+                if idx not in self.inputs_quant_settings_overrides and not self.inputs_quant_auto_fallback:
+                    raise RuntimeError('<{}> Input {}: Expected tensor with embedded quantization metadata. Either:\n'
+                                       '1. Make sure the previous operation is quantized\n'
+                                       '2. Provide explicit input quantization settings\n'
+                                       '3. Set inputs_quant_auto_fallback'.format(self.distiller_name, idx))
+                if self.preset_act_stats:
+                    input.quant_metadata = self.inputs_quant_metadata_fallback[idx]
+                else:
+                    if idx in self.inputs_quant_settings_overrides:
+                        q_settings = self.inputs_quant_settings_overrides[idx]
+                    else:
+                        # If we're here then inputs_quant_auto_fallback is set
+                        # if self.num_forwards == 0:
+                        #     msglogger.info('<{}> Input {}: No embedded quantization metadata, '
+                        #                    'falling back to output settings'.format(self.distiller_name, idx))
+                        q_settings = deepcopy(self.output_quant_settings)
+                        q_settings.clip_half_range = False
+                    scale, zp = _get_quant_params_from_tensor(input, q_settings.num_bits, q_settings.quant_mode,
+                                                              q_settings.clip_mode, q_settings.per_channel,
+                                                              q_settings.clip_n_stds, q_settings.clip_half_range,
+                                                              self.scale_approx_mult_bits)
+                    signed = q_settings.quant_mode != LinearQuantMode.ASYMMETRIC_UNSIGNED
+                    min_q_val, max_q_val = get_quantized_range(q_settings.num_bits, signed)
+                    input.quant_metadata = TensorQuantMetadata(scale, zp, min_q_val, max_q_val)
 
     def quantized_forward(self, *inputs_q):
         """
@@ -393,19 +491,26 @@ class RangeLinearQuantWrapper(nn.Module):
         raise NotImplementedError
 
     def extra_repr(self):
-        tmpstr = 'mode={0}, '.format(str(self.mode).split('.')[1])
-        tmpstr += 'num_bits_acts={0}, num_bits_accum={1}, '.format(self.num_bits_acts, self.num_bits_accum)
-        tmpstr += 'clip_acts={0}, '.format(_enum_to_str(self.clip_acts))
-        if self.clip_acts == ClipMode.N_STD:
-            tmpstr += 'num_stds={} '.format(self.clip_n_stds)
-        tmpstr += 'scale_approx_mult_bits={}'.format(self.scale_approx_mult_bits)
+        tmpstr = 'output_quant_settings={0}'.format(self.output_quant_settings)
+        tmpstr += '\naccum_quant_settings={0}'.format(self.accum_quant_settings)
+        tmpstr += '\nrequires_quantized_inputs={0}'.format(self.requires_quantized_inputs)
+        if self.requires_quantized_inputs:
+            overrides = self.inputs_quant_settings_overrides
+            tmpstr += '\n  inputs_quant_auto_fallback={}'.format(self.inputs_quant_auto_fallback)
+            tmpstr += ', forced_quant_settings_for_inputs={}'.format(
+                'None' if not overrides else list(overrides.keys()))
+            for idx, qset in overrides.items():
+                tmpstr += '\n    input_{}_settings={}'.format(idx, qset)
+        tmpstr += '\nscale_approx_mult_bits={}'.format(self.scale_approx_mult_bits)
         tmpstr += '\npreset_activation_stats={0}'.format(self.preset_act_stats)
         if self.preset_act_stats:
-            for idx, (in_scale, in_zp) in enumerate(zip(self.inputs_scales(), self.inputs_zero_points())):
-                tmpstr += '\nin_{i}_scale={sc}, in_{i}_zero_point={zp}'.format(i=idx, sc=in_scale.item(),
-                                                                               zp=in_zp.item())
-            tmpstr += '\nout_scale={sc}, out_zero_point={zp}'.format(sc=self.output_scale.item(),
-                                                                     zp=self.output_zero_point.item())
+            tmpstr += '\n  output_scale={0}, output_zero_point={1}'.format(_quant_param_to_str(
+                self.output_scale), _quant_param_to_str(self.output_zero_point))
+            if self.requires_quantized_inputs:
+                for idx in self.inputs_quant_settings_overrides:
+                    qmd = self.inputs_quant_metadata_fallback[idx]
+                    tmpstr += '\n  input_#{0}_scale={1}, input_#{0}_zero_point={2}'.format(
+                        idx, _quant_param_to_str(qmd.scale), _quant_param_to_str(qmd.zero_point))
         return tmpstr
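To make the metadata contract above concrete, here is a minimal sketch (not part of the patch) of the attribute a wrapper attaches to its de-quantized output and that _prepare_inputs_for_quantization expects on incoming tensors; the locally defined TensorQuantMetadata namedtuple and the numeric values are assumptions mirroring the usage in forward() above.

from collections import namedtuple
import torch

# Assumed to mirror the (scale, zero_point, min_q_val, max_q_val) container used in forward();
# the real definition lives elsewhere in this module.
TensorQuantMetadata = namedtuple('TensorQuantMetadata',
                                 ['scale', 'zero_point', 'min_q_val', 'max_q_val'])

out_f = torch.randn(1, 16)  # stand-in for the de-quantized output of a wrapped layer
out_f.quant_metadata = TensorQuantMetadata(torch.tensor(0.05), torch.tensor(0.), -128, 127)

# A downstream wrapper created with requires_quantized_inputs=True reads this attribute in
# _prepare_inputs_for_quantization() instead of re-deriving scale/zero-point. If the attribute
# is missing and neither an input override nor inputs_quant_auto_fallback is set, the
# RuntimeError shown above is raised.
next_layer_in_scale = out_f.quant_metadata.scale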
 
 
@@ -444,29 +549,35 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
         per_channel_wts (bool): Enable quantization of weights using separate quantization parameters per
             output channel
         activation_stats (dict): See RangeLinearQuantWrapper
-        clip_n_stds (int): See RangeLinearQuantWrapper
+        clip_n_stds (float): See RangeLinearQuantWrapper
         clip_half_range (bool) : See RangeLinearQuantWrapper
         scale_approx_mult_bits (int): See RangeLinearQuantWrapper
     """
     def __init__(self, wrapped_module, num_bits_acts, num_bits_params, num_bits_accum=32,
                  mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE, per_channel_wts=False, activation_stats=None,
-                 clip_n_stds=None, clip_half_range=False, scale_approx_mult_bits=None):
+                 clip_n_stds=None, clip_half_range=False, scale_approx_mult_bits=None,
+                 input_overrides=None, inputs_quant_auto_fallback=False):
         super(RangeLinearQuantParamLayerWrapper, self).__init__(wrapped_module, num_bits_acts, num_bits_accum, mode,
-                                                                clip_acts, activation_stats, clip_n_stds,
-                                                                clip_half_range, scale_approx_mult_bits)
+                                                                clip_acts, activation_stats, clip_n_stds, clip_half_range,
+                                                                scale_approx_mult_bits,
+                                                                input_overrides=input_overrides,
+                                                                requires_quantized_inputs=True,
+                                                                inputs_quant_auto_fallback=inputs_quant_auto_fallback)
 
         if not isinstance(wrapped_module, (nn.Conv2d, nn.Conv3d, nn.Linear)):
             raise ValueError(self.__class__.__name__ + ' can wrap only Conv2D, Conv3D and Linear modules')
 
-        self.num_bits_params = num_bits_params
-        self.per_channel_wts = per_channel_wts
+        self.wts_quant_settings = QuantSettings(num_bits_params, mode, ClipMode.NONE, None, False, per_channel_wts)
 
         self.params_min_q_val, self.params_max_q_val = get_quantized_range(
-            num_bits_params, signed=mode != LinearQuantMode.ASYMMETRIC_UNSIGNED)
+            self.wts_quant_settings.num_bits,
+            self.wts_quant_settings.quant_mode != LinearQuantMode.ASYMMETRIC_UNSIGNED)
 
         # Quantize weights - overwrite FP32 weights
-        w_scale, w_zero_point = _get_quant_params_from_tensor(wrapped_module.weight, num_bits_params, self.mode,
-                                                              per_channel=per_channel_wts)
+        w_scale, w_zero_point = _get_quant_params_from_tensor(wrapped_module.weight,
+                                                              self.wts_quant_settings.num_bits,
+                                                              self.wts_quant_settings.quant_mode,
+                                                              per_channel=self.wts_quant_settings.per_channel)
 
         self.register_buffer('w_scale', w_scale)
         self.register_buffer('w_zero_point', w_zero_point)
@@ -477,27 +588,25 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
         device = self.w_scale.device
 
         if self.preset_act_stats:
-            self.in_0_scale = self.in_0_scale.to(device)
-            self.register_buffer('accum_scale', self.in_0_scale * self.w_scale)
-            if self.per_channel_wts:
-                self.accum_scale = self.accum_scale.squeeze(dim=-1)
+            t = torch.zeros_like(self.w_scale)
+            if self.wts_quant_settings.per_channel:
+                t = t.squeeze(dim=-1)
+            self.register_buffer('accum_scale', t)
         else:
-            self.accum_scale = 1
+            self.accum_scale = torch.ones(1).to(device)
 
         # Quantize bias
         self.has_bias = hasattr(wrapped_module, 'bias') and wrapped_module.bias is not None
-        if self.has_bias:
-            if self.preset_act_stats:
-                linear_quantize_clamp(wrapped_module.bias.data, self.accum_scale.squeeze(), 0,
-                                      self.accum_min_q_val, self.accum_max_q_val, inplace=True)
-            else:
-                b_scale, b_zero_point = _get_quant_params_from_tensor(wrapped_module.bias, num_bits_params, self.mode)
-                self.register_buffer('b_scale', b_scale)
-                self.register_buffer('b_zero_point', b_zero_point)
-                base_b_q = linear_quantize_clamp(wrapped_module.bias.data, self.b_scale, self.b_zero_point,
-                                                 self.params_min_q_val, self.params_max_q_val)
-                # Dynamic ranges - save in auxiliary buffer, requantize each time based on dynamic input scale factor
-                self.register_buffer('base_b_q', base_b_q)
+        if self.has_bias and not self.preset_act_stats:
+            b_scale, b_zero_point = _get_quant_params_from_tensor(wrapped_module.bias,
+                                                                  self.wts_quant_settings.num_bits,
+                                                                  self.wts_quant_settings.quant_mode)
+            self.register_buffer('b_scale', b_scale)
+            self.register_buffer('b_zero_point', b_zero_point)
+            base_b_q = linear_quantize_clamp(wrapped_module.bias.data, self.b_scale, self.b_zero_point,
+                                             self.params_min_q_val, self.params_max_q_val)
+            # Dynamic ranges - save in auxiliary buffer, requantize each time based on dynamic input scale factor
+            self.register_buffer('base_b_q', base_b_q)
 
         # A flag indicating that the simulated quantized weights are pre-shifted. for faster performance.
         # In the first forward pass - `w_zero_point` is added into the weights, to allow faster inference,
@@ -513,21 +622,26 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
             self.is_simulated_quant_weight_shifted.sub_(1) # i.e. is_simulated_quant_weight_shifted = False
         return super(RangeLinearQuantParamLayerWrapper, self).state_dict(destination, prefix, keep_vars)
 
-    def get_inputs_quantization_params(self, input):
-        if not self.preset_act_stats:
-            self.in_0_scale, self.in_0_zero_point = _get_quant_params_from_tensor(
-                input, self.num_bits_acts, self.mode, clip=self.clip_acts,
-                num_stds=self.clip_n_stds, scale_approx_mult_bits=self.scale_approx_mult_bits)
-        return [self.in_0_scale], [self.in_0_zero_point]
-
     def quantized_forward(self, input_q):
         # See class documentation for quantized calculation details.
 
-        if not self.preset_act_stats:
-            self.accum_scale = self.in_0_scale * self.w_scale
-            if self.per_channel_wts:
-                self.accum_scale = self.accum_scale.squeeze(dim=-1)
+        def get_accum_scale(input_q):
+            accum_scale = input_q.quant_metadata.scale * self.w_scale
+            if self.wts_quant_settings.per_channel:
+                accum_scale = accum_scale.squeeze(dim=-1)
+            if self.scale_approx_mult_bits:
+                accum_scale = approx_scale_as_mult_and_shift(accum_scale, self.scale_approx_mult_bits)
+            return accum_scale
 
+        if self.preset_act_stats:
+            if self.num_forwards == 0:
+                self.accum_scale += get_accum_scale(input_q)
+                if self.has_bias:
+                    # Requantize bias to accumulator scale "permanently"
+                    linear_quantize_clamp(self.wrapped_module.bias.data, self.accum_scale.squeeze(), 0,
+                                          self.accum_min_q_val, self.accum_max_q_val, inplace=True)
+        else:
+            self.accum_scale = get_accum_scale(input_q)
             if self.has_bias:
                 # Re-quantize bias to match x * w scale: b_q' = (in_scale * w_scale / b_scale) * (b_q + b_zero_point)
                 bias_requant_scale = self.accum_scale.squeeze() / self.b_scale
@@ -546,13 +660,13 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
         # to the input and weights and pass those to the wrapped model. Functionally, since at this point we're
         # dealing solely with integer values, the results are the same either way.
 
-        if self.mode != LinearQuantMode.SYMMETRIC and not self.is_simulated_quant_weight_shifted:
+        if self.output_quant_settings.quant_mode != LinearQuantMode.SYMMETRIC and not self.is_simulated_quant_weight_shifted:
             # We "store" the w_zero_point inside our wrapped module's weights to
             # improve performance on inference.
             self.wrapped_module.weight.data += self.w_zero_point
-            self.is_simulated_quant_weight_shifted.add_(1) # i.e. is_simulated_quant_weight_shifted = True
+            self.is_simulated_quant_weight_shifted.add_(1)  # i.e. is_simulated_quant_weight_shifted = True
 
-        input_q += self.in_0_zero_point
+        input_q += input_q.quant_metadata.zero_point
         accum = self.wrapped_module.forward(input_q)
         clamp(accum.data, self.accum_min_q_val, self.accum_max_q_val, inplace=True)
 
@@ -563,8 +677,10 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
             return self.output_scale, self.output_zero_point
 
         y_f = accumulator / self.accum_scale
-        return _get_quant_params_from_tensor(y_f, self.num_bits_acts, self.mode, clip=self.clip_acts,
-                                             num_stds=self.clip_n_stds,
+        q_set = self.output_quant_settings
+        return _get_quant_params_from_tensor(y_f, q_set.num_bits, q_set.quant_mode,
+                                             clip=q_set.clip_mode, num_stds=q_set.clip_n_stds,
+                                             half_range=q_set.clip_half_range,
                                              scale_approx_mult_bits=self.scale_approx_mult_bits)
 
     def get_accum_to_output_re_quantization_params(self, output_scale, output_zero_point):
@@ -574,28 +690,13 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
         return requant_scale, output_zero_point
 
     def extra_repr(self):
-        tmpstr = 'mode={0}, '.format(str(self.mode).split('.')[1])
-        tmpstr += 'num_bits_acts={0}, num_bits_params={1}, num_bits_accum={2}, '.format(self.num_bits_acts,
-                                                                                        self.num_bits_params,
-                                                                                        self.num_bits_accum)
-        tmpstr += 'clip_acts={0}, '.format(_enum_to_str(self.clip_acts))
-        if self.clip_acts == ClipMode.N_STD:
-            tmpstr += 'num_stds={} '.format(self.clip_n_stds)
-        tmpstr += 'per_channel_wts={}, scale_approx_mult_bits={}'.format(self.per_channel_wts,
-                                                                         self.scale_approx_mult_bits)
-        tmpstr += '\npreset_activation_stats={0}'.format(self.preset_act_stats)
-        if self.per_channel_wts:
-            tmpstr += '\nw_scale=PerCh, w_zero_point=PerCh'
-        else:
-            tmpstr += '\nw_scale={0:.4f}, w_zero_point={1:.4f}'.format(self.w_scale.item(), self.w_zero_point.item())
-        if self.preset_act_stats:
-            tmpstr += '\nin_scale={0:.4f}, in_zero_point={1:.4f}'.format(self.in_0_scale.item(),
-                                                                         self.in_0_zero_point.item())
-            tmpstr += '\nout_scale={0:.4f}, out_zero_point={1:.4f}'.format(self.output_scale.item(),
-                                                                           self.output_zero_point.item())
-        elif self.has_bias:
-            tmpstr += '\nbase_b_scale={0:.4f}, base_b_zero_point={1:.4f}'.format(self.b_scale.item(),
-                                                                                 self.b_zero_point.item())
+        tmpstr = 'weights_quant_settings={0}\n'.format(self.wts_quant_settings)
+        tmpstr += super(RangeLinearQuantParamLayerWrapper, self).extra_repr()
+        tmpstr += '\nweights_scale={0}, weights_zero_point={1}'.format(_quant_param_to_str(self.w_scale),
+                                                                       _quant_param_to_str(self.w_zero_point))
+        if not self.preset_act_stats and self.has_bias:
+            tmpstr += '\nbase_bias_scale={0}, base_bias_zero_point={1}'.format(_quant_param_to_str(self.b_scale),
+                                                                               _quant_param_to_str(self.b_zero_point))
         return tmpstr
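A brief usage sketch of the constructor signature above, assuming it runs inside this module so that RangeLinearQuantParamLayerWrapper, LinearQuantMode and ClipMode are in scope; with activation_stats=None the wrapper falls back to deriving activation quantization parameters at runtime.

import torch.nn as nn

fc = nn.Linear(64, 10)
q_fc = RangeLinearQuantParamLayerWrapper(
    fc, num_bits_acts=8, num_bits_params=8, num_bits_accum=32,
    mode=LinearQuantMode.SYMMETRIC,
    clip_acts=ClipMode.NONE,
    per_channel_wts=True,
    activation_stats=None,            # no collected stats -> qparams derived dynamically
    input_overrides=None,
    inputs_quant_auto_fallback=True)  # inputs lacking quant_metadata reuse the output settings
q_fc.eval()                           # the wrapper can only be used in eval mode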
 
 
@@ -625,32 +726,24 @@ class RangeLinearQuantMatmulWrapper(RangeLinearQuantWrapper):
         scale_approx_mult_bits (int): See RangeLinearQuantWrapper
     """
     def __init__(self, wrapped_module, num_bits_acts, num_bits_accum=32,
-                 mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE,  activation_stats=None,
-                 clip_n_stds=None, clip_half_range=False, scale_approx_mult_bits=None):
+                 mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE, activation_stats=None,
+                 clip_n_stds=None, clip_half_range=False, scale_approx_mult_bits=None,
+                 input_overrides=None, inputs_quant_auto_fallback=False):
         super(RangeLinearQuantMatmulWrapper, self).__init__(wrapped_module, num_bits_acts, num_bits_accum, mode,
                                                             clip_acts, activation_stats, clip_n_stds, clip_half_range,
-                                                            scale_approx_mult_bits)
+                                                            scale_approx_mult_bits,
+                                                            input_overrides=input_overrides,
+                                                            requires_quantized_inputs=True,
+                                                            inputs_quant_auto_fallback=inputs_quant_auto_fallback)
 
         if not isinstance(wrapped_module, (distiller.modules.Matmul, distiller.modules.BatchMatmul)):
             raise ValueError(self.__class__.__name__ + ' can wrap only Matmul modules')
-        if self.preset_act_stats:
-            self.register_buffer('accum_scale', self.in_0_scale * self.in_1_scale)
-        else:
-            self.accum_scale = 1
-
-    def get_inputs_quantization_params(self, input0, input1):
-        if not self.preset_act_stats:
-            self.in_0_scale, self.in_0_zero_point = _get_quant_params_from_tensor(
-                input0, self.num_bits_acts, self.mode, clip=self.clip_acts,
-                num_stds=self.clip_n_stds, scale_approx_mult_bits=self.scale_approx_mult_bits)
-            self.in_1_scale, self.in_1_zero_point = _get_quant_params_from_tensor(
-                input0, self.num_bits_acts, self.mode, clip=self.clip_acts,
-                num_stds=self.clip_n_stds, scale_approx_mult_bits=self.scale_approx_mult_bits)
-        return [self.in_0_scale, self.in_1_scale], [self.in_0_zero_point, self.in_1_zero_point]
+        self.accum_scale = 1
 
     def quantized_forward(self, input0_q, input1_q):
-        accum = self.wrapped_module.forward(input0_q + self.in_0_zero_point,
-                                            input1_q + self.in_1_zero_point)
+        self.accum_scale = input0_q.quant_metadata.scale * input1_q.quant_metadata.scale
+        accum = self.wrapped_module.forward(input0_q + input0_q.quant_metadata.zero_point,
+                                            input1_q + input1_q.quant_metadata.zero_point)
         clamp(accum.data, self.accum_min_q_val, self.accum_max_q_val, inplace=True)
         return accum
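As a small numeric illustration of the accumulator scale computed in quantized_forward above (made-up scales, symmetric quantization with zero-point 0): the integer product of two quantized inputs lives at the product of their scales, which is what get_output_quantization_params later divides by.

import torch

s0, s1 = 20.0, 8.0                # hypothetical scales carried in the inputs' quant_metadata
x0, x1 = torch.randn(4, 6), torch.randn(6, 5)
x0_q = torch.round(x0 * s0)       # quantized inputs (zero-point 0)
x1_q = torch.round(x1 * s1)
accum = x0_q @ x1_q               # integer-valued accumulator
accum_scale = s0 * s1             # = 160.0, matching accum_scale in quantized_forward above
y_f = accum / accum_scale         # approximates x0 @ x1; used to derive the output qparams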
 
@@ -659,8 +752,10 @@ class RangeLinearQuantMatmulWrapper(RangeLinearQuantWrapper):
             return self.output_scale, self.output_zero_point
 
         y_f = accumulator / self.accum_scale
-        return _get_quant_params_from_tensor(y_f, self.num_bits_acts, self.mode, clip=self.clip_acts,
-                                             num_stds=self.clip_n_stds,
+        q_set = self.output_quant_settings
+        return _get_quant_params_from_tensor(y_f, q_set.num_bits, q_set.quant_mode,
+                                             clip=q_set.clip_mode, num_stds=q_set.clip_n_stds,
+                                             half_range=q_set.clip_half_range,
                                              scale_approx_mult_bits=self.scale_approx_mult_bits)
 
     def get_accum_to_output_re_quantization_params(self, output_scale, output_zero_point):
@@ -669,25 +764,6 @@ class RangeLinearQuantMatmulWrapper(RangeLinearQuantWrapper):
             requant_scale = approx_scale_as_mult_and_shift(requant_scale, self.scale_approx_mult_bits)
         return requant_scale, output_zero_point
 
-    def extra_repr(self):
-        tmpstr = 'mode={0}, '.format(str(self.mode).split('.')[1])
-        tmpstr += 'num_bits_acts={0},  num_bits_accum={1}, '.format(self.num_bits_acts, self.num_bits_accum)
-        tmpstr += 'clip_acts={0}, '.format(_enum_to_str(self.clip_acts))
-        if self.clip_acts == ClipMode.N_STD:
-            tmpstr += 'num_stds={} '.format(self.clip_n_stds)
-        tmpstr += 'scale_approx_mult_bits={}'.format(self.scale_approx_mult_bits)
-        tmpstr += '\npreset_activation_stats={0}'.format(self.preset_act_stats)
-        if self.preset_act_stats:
-            tmpstr += '\nin_0_scale={0:.4f}, in_0_zero_point={1:.4f}'.format(self.in_0_scale.item(),
-                                                                         self.in_0_zero_point.item())
-
-            tmpstr += '\nin_1_scale={0:.4f}, in_1_zero_point={1:.4f}'.format(self.in_1_scale.item(),
-                                                                         self.in_1_zero_point.item())
-
-            tmpstr += '\nout_scale={0:.4f}, out_zero_point={1:.4f}'.format(self.output_scale.item(),
-                                                                           self.output_zero_point.item())
-        return tmpstr
-
 
 class NoStatsError(NotImplementedError):
     pass
@@ -695,7 +771,8 @@ class NoStatsError(NotImplementedError):
 
 class RangeLinearQuantConcatWrapper(RangeLinearQuantWrapper):
     def __init__(self, wrapped_module, num_bits_acts, mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE,
-                 activation_stats=None, clip_n_stds=None, clip_half_range=False, scale_approx_mult_bits=None):
+                 activation_stats=None, clip_n_stds=None, clip_half_range=False, scale_approx_mult_bits=None,
+                 input_overrides=None, inputs_quant_auto_fallback=False):
         if not isinstance(wrapped_module, distiller.modules.Concat):
             raise ValueError(self.__class__.__name__ + ' can only wrap distiller.modules.Concat modules')
 
@@ -706,33 +783,18 @@ class RangeLinearQuantConcatWrapper(RangeLinearQuantWrapper):
         super(RangeLinearQuantConcatWrapper, self).__init__(wrapped_module, num_bits_acts, mode=mode,
                                                             clip_acts=clip_acts, activation_stats=activation_stats,
                                                             clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
-                                                            scale_approx_mult_bits=scale_approx_mult_bits)
-
-        if self.preset_act_stats:
-            # For concatenation to make sense, we need to match all the inputs' scales, so we
-            # set a re-scale factor based on the preset output scale
-            for idx, in_scale in enumerate(self.inputs_scales()):
-                requant_scale = self.output_scale / in_scale
-                if self.scale_approx_mult_bits is not None:
-                    requant_scale = approx_scale_as_mult_and_shift(requant_scale, self.scale_approx_mult_bits)
-                self.register_buffer('in_{0}_requant_scale'.format(idx), requant_scale)
-
-    def inputs_requant_scales(self):
-        if not self.preset_act_stats:
-            raise RuntimeError('Input quantization parameter iterators only available when activation stats were given')
-        for idx in range(self.num_inputs):
-            name = 'in_{0}_requant_scale'.format(idx)
-            yield getattr(self, name)
-
-    def get_inputs_quantization_params(self, *inputs):
-        return self.inputs_scales(), self.inputs_zero_points()
+                                                            scale_approx_mult_bits=scale_approx_mult_bits,
+                                                            input_overrides=input_overrides,
+                                                            requires_quantized_inputs=True,
+                                                            inputs_quant_auto_fallback=inputs_quant_auto_fallback)
 
     def quantized_forward(self, *inputs_q):
-        # Re-quantize all inputs based to the same range (the output range)
-        inputs_re_q = [linear_quantize_clamp(input_q + zp, requant_scale, self.output_zero_point,
-                                             self.acts_min_q_val, self.acts_max_q_val, inplace=False)
-                       for input_q, requant_scale, zp in zip(inputs_q, self.inputs_requant_scales(),
-                                                             self.inputs_zero_points())]
+        # For concatenation to make sense, input scales need to match, so we re-quantize all inputs
+        # based on the output scale
+        inputs_re_q = [linear_quantize_clamp(input_q + input_q.quant_metadata.zero_point,
+                                             self.output_scale / input_q.quant_metadata.scale, 0.,
+                                             self.accum_min_q_val, self.accum_max_q_val, inplace=False)
+                       for input_q in inputs_q]
         return self.wrapped_module(*inputs_re_q)
 
     def get_output_quantization_params(self, accumulator):
@@ -740,12 +802,13 @@ class RangeLinearQuantConcatWrapper(RangeLinearQuantWrapper):
 
     def get_accum_to_output_re_quantization_params(self, output_scale, output_zero_point):
         # Nothing to do here, since we already re-quantized in quantized_forward prior to the actual concatenation
-        return 1., 0.
+        return 1., self.output_zero_point
 
 
 class RangeLinearQuantEltwiseAddWrapper(RangeLinearQuantWrapper):
     def __init__(self, wrapped_module, num_bits_acts, mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE,
-                 activation_stats=None, clip_n_stds=None, clip_half_range=False, scale_approx_mult_bits=None):
+                 activation_stats=None, clip_n_stds=None, clip_half_range=False, scale_approx_mult_bits=None,
+                 input_overrides=None, inputs_quant_auto_fallback=False):
         if not isinstance(wrapped_module, distiller.modules.EltwiseAdd):
             raise ValueError(self.__class__.__name__ + ' can only wrap distiller.modules.EltwiseAdd modules')
 
@@ -755,36 +818,18 @@ class RangeLinearQuantEltwiseAddWrapper(RangeLinearQuantWrapper):
 
         super(RangeLinearQuantEltwiseAddWrapper, self).__init__(wrapped_module, num_bits_acts, mode=mode,
                                                                 clip_acts=clip_acts, activation_stats=activation_stats,
-                                                                clip_n_stds=clip_n_stds,
-                                                                clip_half_range=clip_half_range,
-                                                                scale_approx_mult_bits=scale_approx_mult_bits)
-
-        if self.preset_act_stats:
-            # For addition to make sense, all input scales must match. So we set a re-scale factor according
-            # to the preset output scale
-            requant_scales = [self.output_scale / in_scale for in_scale in self.inputs_scales()]
-            if scale_approx_mult_bits is not None:
-                requant_scales = [approx_scale_as_mult_and_shift(requant_scale, scale_approx_mult_bits)
-                                  for requant_scale in requant_scales]
-            for idx, requant_scale in enumerate(requant_scales):
-                self.register_buffer('in_{0}_requant_scale'.format(idx), requant_scale)
-
-    def inputs_requant_scales(self):
-        if not self.preset_act_stats:
-            raise RuntimeError('Input quantization parameter iterators only available when activation stats were given')
-        for idx in range(self.num_inputs):
-            name = 'in_{0}_requant_scale'.format(idx)
-            yield getattr(self, name)
-
-    def get_inputs_quantization_params(self, *inputs):
-        return self.inputs_scales(), self.inputs_zero_points()
+                                                                clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
+                                                                scale_approx_mult_bits=scale_approx_mult_bits,
+                                                                input_overrides=input_overrides,
+                                                                requires_quantized_inputs=True,
+                                                                inputs_quant_auto_fallback=inputs_quant_auto_fallback)
 
     def quantized_forward(self, *inputs_q):
-        # Re-scale inputs to the accumulator range
-        inputs_re_q = [linear_quantize_clamp(input_q + zp, requant_scale, 0,
+        # Re-scale inputs to the accumulator scale
+        inputs_re_q = [linear_quantize_clamp(input_q + input_q.quant_metadata.zero_point,
+                                             self.output_scale / input_q.quant_metadata.scale, 0,
                                              self.accum_min_q_val, self.accum_max_q_val, inplace=False)
-                       for input_q, requant_scale, zp in zip(inputs_q, self.inputs_requant_scales(),
-                                                             self.inputs_zero_points())]
+                       for input_q in inputs_q]
         accum = self.wrapped_module(*inputs_re_q)
         clamp(accum.data, self.accum_min_q_val, self.accum_max_q_val, inplace=True)
 
@@ -799,7 +844,8 @@ class RangeLinearQuantEltwiseAddWrapper(RangeLinearQuantWrapper):
 
 class RangeLinearQuantEltwiseMultWrapper(RangeLinearQuantWrapper):
     def __init__(self, wrapped_module, num_bits_acts, mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE,
-                 activation_stats=None, clip_n_stds=None, clip_half_range=False, scale_approx_mult_bits=None):
+                 activation_stats=None, clip_n_stds=None, clip_half_range=False, scale_approx_mult_bits=None,
+                 input_overrides=None, inputs_quant_auto_fallback=False):
         if not isinstance(wrapped_module, distiller.modules.EltwiseMult):
             raise ValueError(self.__class__.__name__ + ' can only wrap distiller.modules.EltwiseMult modules')
 
@@ -809,20 +855,19 @@ class RangeLinearQuantEltwiseMultWrapper(RangeLinearQuantWrapper):
 
         super(RangeLinearQuantEltwiseMultWrapper, self).__init__(wrapped_module, num_bits_acts, mode=mode,
                                                                  clip_acts=clip_acts, activation_stats=activation_stats,
-                                                                 clip_n_stds=clip_n_stds,
-                                                                 clip_half_range=clip_half_range,
-                                                                 scale_approx_mult_bits=scale_approx_mult_bits)
-
-        if self.preset_act_stats:
-            self.register_buffer('accum_scale', reduce(lambda x, y: x * y, self.inputs_scales()))
-
-    def get_inputs_quantization_params(self, *inputs):
-        return self.inputs_scales(), self.inputs_zero_points()
+                                                                 clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
+                                                                 scale_approx_mult_bits=scale_approx_mult_bits,
+                                                                 input_overrides=input_overrides,
+                                                                 requires_quantized_inputs=True,
+                                                                 inputs_quant_auto_fallback=inputs_quant_auto_fallback)
+        self.accum_scale = 1
 
     def quantized_forward(self, *inputs_q):
-        if self.mode != LinearQuantMode.SYMMETRIC:
-            for input_q, zp in zip(inputs_q, self.inputs_zero_points()):
-                input_q += zp
+        input_scales = [input_q.quant_metadata.scale for input_q in inputs_q]
+        self.accum_scale = reduce(lambda x, y: x * y, input_scales)
+
+        for input_q in inputs_q:
+            input_q += input_q.quant_metadata.zero_point
 
         accum = self.wrapped_module(*inputs_q)
         clamp(accum.data, self.accum_min_q_val, self.accum_max_q_val, inplace=True)
@@ -839,26 +884,28 @@ class RangeLinearQuantEltwiseMultWrapper(RangeLinearQuantWrapper):
         return requant_scale, output_zero_point
 
 
-class FP16Wrapper(nn.Module):
+class FPWrapper(nn.Module):
     """
-    A wrapper that replaces a module with a half precision version.
+    A wrapper that replaces a module with a version cast to a specified floating point precision.
-
     Args:
         module (nn.Module): The module to be replaced.
-        convert_input (:obj:`bool`, optional): Specifies whether an input conversion
-            to fp16 is required for forward. Default: True.
-        return_fp32 (:obj:`bool`, optional): Specifies whether the output needs
+        precision (Union[str, int]): The floating point precision to use: one of 16, 32 or 64.
+        convert_input (bool): Specifies whether an input conversion
+            to module precision is required for forward. Default: True.
+        return_fp32 (bool): Specifies whether the output needs
             to be converted back to fp32. Default: True.
     """
-    def __init__(self, module: nn.Module, convert_input=True, return_fp32=True):
-        super(FP16Wrapper, self).__init__()
-        self.wrapped_module = module.half()
+    def __init__(self, module: nn.Module, precision, convert_input=True, return_fp32=True):
+        super(FPWrapper, self).__init__()
+        precision = str(precision)
+        self.dtype = {'16': torch.float16, '32': torch.float32, '64': torch.float64}[precision]
+        self.wrapped_module = module.to(self.dtype)
         self.return_fp32 = return_fp32
-        self.convert_input_fp16 = convert_input
+        self.convert_input = convert_input
 
     def forward(self, *input):
-        if self.convert_input_fp16:
-            input = distiller.convert_tensors_recursively_to(input, dtype=torch.float16)
+        if self.convert_input:
+            input = distiller.convert_tensors_recursively_to(input, dtype=self.dtype)
 
         result = self.wrapped_module(*input)
         if self.return_fp32:
@@ -867,6 +914,11 @@ class FP16Wrapper(nn.Module):
         return result
 
 
+class FP16Wrapper(FPWrapper):
+    def __init__(self, module, convert_input=True, return_fp32=True):
+        super(FP16Wrapper, self).__init__(module, 16, convert_input, return_fp32)
+
+
 class RangeLinearEmbeddingWrapper(nn.Module):
     def __init__(self, wrapped_module, num_bits, mode=LinearQuantMode.SYMMETRIC, stats=None):
         if not isinstance(wrapped_module, nn.Embedding):
@@ -888,20 +940,69 @@ class RangeLinearEmbeddingWrapper(nn.Module):
         self.register_buffer('w_zero_point', w_zero_point.to(device))
         linear_quantize_clamp(wrapped_module.weight.data, self.w_scale, self.w_zero_point, self.min_q_val,
                               self.max_q_val, inplace=True)
-
+        self.quant_metadata = TensorQuantMetadata(self.w_scale, self.w_zero_point, self.min_q_val, self.max_q_val)
         self.wrapped_module = wrapped_module
 
     def forward(self, input):
         out_q = self.wrapped_module(input)
         out_f = linear_dequantize(out_q, self.w_scale, self.w_zero_point, inplace=True)
+        out_f.quant_metadata = self.quant_metadata
         return out_f
 
 
+class RangeLinearFakeQuantWrapper(RangeLinearQuantWrapper):
+    def __init__(self, wrapped_module, num_bits_acts, mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE,
+                 activation_stats=None, clip_n_stds=None, clip_half_range=False, scale_approx_mult_bits=None,
+                 fpq_module=None):
+        super(RangeLinearFakeQuantWrapper, self).__init__(wrapped_module, num_bits_acts, mode=mode,
+                                                          clip_acts=clip_acts, activation_stats=activation_stats,
+                                                          clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
+                                                          scale_approx_mult_bits=scale_approx_mult_bits,
+                                                          requires_quantized_inputs=False)
+        self.fpq_module = str(fpq_module) if fpq_module else None
+        self.dtype = torch.float
+        if self.fpq_module:
+            self.dtype = {'16': torch.half, '32': torch.float, '64': torch.double}[self.fpq_module]
+            self.wrapped_module.to(self.dtype)
+
+    def quantized_forward(self, *inputs_q):
+        inputs_q = distiller.convert_tensors_recursively_to(inputs_q, dtype=self.dtype)
+        outputs = self.wrapped_module(*inputs_q)
+        return distiller.convert_tensors_recursively_to(outputs, dtype=self.dtype)
+
+    def get_output_quantization_params(self, accumulator):
+        if self.preset_act_stats:
+            return self.output_scale, self.output_zero_point
+        else:
+            q_set = self.output_quant_settings
+            return _get_quant_params_from_tensor(accumulator, q_set.num_bits, q_set.quant_mode, q_set.clip_mode,
+                                                 q_set.per_channel, q_set.clip_n_stds, q_set.clip_half_range,
+                                                 self.scale_approx_mult_bits)
+
+    def get_accum_to_output_re_quantization_params(self, output_scale, output_zero_point):
+        return output_scale, output_zero_point
+
+
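A minimal sketch of fake quantization with the wrapper above, assuming it runs inside this module so that RangeLinearFakeQuantWrapper, LinearQuantMode and ClipMode are in scope: the wrapped module executes in the floating point precision given by fpq_module and only its output is quantized and de-quantized.

import torch.nn as nn

pool = nn.AdaptiveAvgPool2d(1)  # a leaf module with no dedicated range-linear wrapper
fq_pool = RangeLinearFakeQuantWrapper(pool, num_bits_acts=8,
                                      mode=LinearQuantMode.SYMMETRIC,
                                      clip_acts=ClipMode.NONE,
                                      activation_stats=None,  # output qparams derived dynamically
                                      fpq_module=16)          # run the wrapped module in FP16
fq_pool.eval()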
 class PostTrainLinearQuantizer(Quantizer):
     """
     Applies range-based linear quantization to a model.
     This quantizer is expected to be executed at evaluation only, on a pre-trained model
-    Currently, the following Modules are supported: torch.nn.Conv2d, torch.nn.Linear
+
+    The following modules / operations have dedicated implementations which consider quantization:
+      * torch.nn.Conv2d/Conv3d
+      * torch.nn.Linear
+      * torch.nn.Embedding
+      * distiller.modules.Concat
+      * distiller.modules.EltwiseAdd
+      * distiller.modules.EltwiseMult
+      * distiller.modules.Matmul
+      * distiller.modules.BatchMatmul
+    An existing model will likely need to be modified to use the 'distiller.modules.*' modules. This needs to
+    be done BEFORE creating the quantizer. See the docs for more details:
+    https://nervanasystems.github.io/distiller/prepare_model_quant.html
+
+    Any leaf module not in the list above will be "fake-quantized". That is, the floating-point module will be
+    executed (FP64/32/16 can be specified with the fpq_module argument), and its output will be quantized.
 
     Args:
         model (torch.nn.Module): Model to be quantized
@@ -916,23 +1017,34 @@ class PostTrainLinearQuantizer(Quantizer):
             The dict should be in the format exported by distiller.data_loggers.QuantCalibrationStatsCollector.
             If None then parameters are calculated dynamically.
         fp16 (bool): Set to True to convert modules to half precision.
+            WARNING - this argument is deprecated; use the `fpq_module` argument instead
         clip_n_stds (float): When clip_acts == ClipMode.N_STD, this is the number of standard deviations to use
+        clip_half_range (bool): Clip activations to the non-negative half of their range (relevant when the
+            output feeds a ReLU). See RangeLinearQuantWrapper.
         scale_approx_mult_bits (int): If not None, scale factors will be approximated using an integer multiplication
             followed by a bit-wise shift. This eliminates floating-point scale factors, replacing them with integer
             calculations.
             If None, scale factors will be kept in their original FP32 values.
+        inputs_quant_auto_fallback (bool): Enabled by default.
+            See <distiller_root>/examples/post_train_quant/resnet18_imagenet_post_train_input_overrides.yaml
+            for details on what this does and how to override it.
+        fpq_module (Union[int, str]): Use the modules in floating point mode and only quantize their outputs.
+            Accepts the values 16, 32 and 64 only; when set, RangeLinearFakeQuantWrapper is used.
     Note:
-        If fp16 is set to True, all the layers (except those overridden in `overrides`) will be converted
-        to half precision, regardless of bits_activations/parameters/accum.
+        If fpq_module is set, all the layers (except those overridden in `overrides`) will be converted
+        to the specified floating point precision, regardless of bits_activations/parameters/accum.
     """
     def __init__(self, model, bits_activations=8, bits_parameters=8, bits_accum=32,
                  overrides=None, mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE,
-                 per_channel_wts=False, model_activation_stats=None, fp16=False, clip_n_stds=None,
-                 scale_approx_mult_bits=None):
+                 per_channel_wts=False, model_activation_stats=None, fp16=False,
+                 clip_n_stds=None, clip_half_range=False,
+                 scale_approx_mult_bits=None, inputs_quant_auto_fallback=True,
+                 fpq_module=None):
+        overrides_bkp = deepcopy(overrides)
         super(PostTrainLinearQuantizer, self).__init__(model, bits_activations=bits_activations,
                                                        bits_weights=bits_parameters, bits_bias=bits_accum,
                                                        overrides=overrides, train_with_fp_copy=False)
-
+        if fp16 and str(fpq_module) not in ('16', 'None'):
+            raise ValueError('Conflict - fp16 is set to True while fpq_module is set to a value other than 16.')
         mode = verify_quant_mode(mode)
         clip_acts = verify_clip_mode(clip_acts)
         if clip_acts == ClipMode.N_STD and clip_n_stds is None:
@@ -964,39 +1076,71 @@ class PostTrainLinearQuantizer(Quantizer):
                                                     'mode': str(mode).split('.')[1],
                                                     'clip_acts': _enum_to_str(clip_acts),
                                                     'clip_n_stds': clip_n_stds,
+                                                    'clip_half_range': clip_half_range,
                                                     'per_channel_wts': per_channel_wts,
-                                                    'fp16': fp16,
-                                                    'scale_approx_mult_bits': scale_approx_mult_bits}}
+                                                    'scale_approx_mult_bits': scale_approx_mult_bits,
+                                                    'inputs_quant_auto_fallback': inputs_quant_auto_fallback,
+                                                    'fpq_module': fpq_module,
+                                                    'model_activation_stats': model_activation_stats,
+                                                    'overrides': overrides_bkp}}
 
         def replace_param_layer(module, name, qbits_map, per_channel_wts=per_channel_wts,
                                 mode=mode, fp16=fp16, scale_approx_mult_bits=scale_approx_mult_bits,
-                                clip_acts=clip_acts, clip_half_range=False, clip_n_stds=clip_n_stds):
+                                clip_acts=clip_acts, clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
+                                input_overrides=None, fpq_module=fpq_module, fake=False):
             if fp16:
-                return FP16Wrapper(module)
-
-            # TODO: Try auto-detecting when clip_half_range is needed
-            #  instead of having the user pass it as a parameter (same for replace_non_param_layer)
+                warnings.warn("Argument 'fp16' is deprecated. Please use the 'fpq_module' argument (16/32/64) instead.",
+                              DeprecationWarning)
+                fpq_module = fpq_module or 16
             norm_name = distiller.utils.normalize_module_name(name)
             clip_acts = verify_clip_mode(clip_acts)
+            if fpq_module:
+                if not fake:
+                    return FPWrapper(module, fpq_module)
+                else:
+                    return RangeLinearFakeQuantWrapper(module, qbits_map[name].acts, mode=mode, clip_acts=clip_acts,
+                                                       activation_stats=self.model_activation_stats.get(norm_name,
+                                                                                                        None),
+                                                       clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
+                                                       scale_approx_mult_bits=scale_approx_mult_bits,
+                                                       fpq_module=fpq_module)
+
             return RangeLinearQuantParamLayerWrapper(module, qbits_map[name].acts, qbits_map[name].wts,
                                                      num_bits_accum=self.bits_accum, mode=mode, clip_acts=clip_acts,
                                                      per_channel_wts=per_channel_wts,
                                                      activation_stats=self.model_activation_stats.get(norm_name, None),
                                                      clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
-                                                     scale_approx_mult_bits=scale_approx_mult_bits)
+                                                     scale_approx_mult_bits=scale_approx_mult_bits,
+                                                     input_overrides=input_overrides,
+                                                     inputs_quant_auto_fallback=inputs_quant_auto_fallback)
 
         def replace_non_param_layer(wrapper_type, module, name, qbits_map, fp16=fp16,
                                     scale_approx_mult_bits=scale_approx_mult_bits,
-                                    clip_acts=clip_acts, clip_n_stds=clip_n_stds, clip_half_range=False):
-            if fp16:
-                return FP16Wrapper(module)
+                                    clip_acts=clip_acts, clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
+                                    input_overrides=None, inputs_quant_auto_fallback=inputs_quant_auto_fallback,
+                                    fpq_module=fpq_module, fake=False):
             norm_name = distiller.utils.normalize_module_name(name)
             clip_acts = verify_clip_mode(clip_acts)
+            if fp16:
+                warnings.warn("Argument 'fp16' is deprecated. Please use the 'fpq_module' argument (16/32/64) instead.",
+                              DeprecationWarning)
+                fpq_module = fpq_module or 16
+            if fpq_module:
+                if fake:
+                    return RangeLinearFakeQuantWrapper(module, qbits_map[name].acts, mode=mode, clip_acts=clip_acts,
+                                                       activation_stats=self.model_activation_stats.get(norm_name, None),
+                                                       clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
+                                                       scale_approx_mult_bits=scale_approx_mult_bits,
+                                                       fpq_module=fpq_module)
+                else:
+                    return FPWrapper(module, fpq_module)
             try:
                 return wrapper_type(module, qbits_map[name].acts, mode=mode, clip_acts=clip_acts,
                                     activation_stats=self.model_activation_stats.get(norm_name, None),
-                                    clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
-                                    scale_approx_mult_bits=scale_approx_mult_bits)
+                                    clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
+                                    scale_approx_mult_bits=scale_approx_mult_bits,
+                                    input_overrides=input_overrides,
+                                    inputs_quant_auto_fallback=inputs_quant_auto_fallback)
             except NoStatsError:
                 msglogger.warning('WARNING: {0} - quantization of {1} without stats not supported. '
                                   'Keeping the original FP32 module'.format(name, module.__class__.__name__))
@@ -1009,6 +1153,31 @@ class PostTrainLinearQuantizer(Quantizer):
             return RangeLinearEmbeddingWrapper(module, qbits_map[name].wts, mode=mode,
                                                stats=self.model_activation_stats.get(norm_name, None))
 
+        def replace_fake_quant(module, name, qbits_map, fp16=fp16,
+                               clip_acts=clip_acts, clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
+                               scale_approx_mult_bits=scale_approx_mult_bits, fpq_module=fpq_module, fake=True,
+                               make_identity=False):
+            if isinstance(module, (nn.ReLU, nn.ReLU6)) and make_identity:
+                named_modules = OrderedDict(self.model.named_modules())
+                pred = self.adjacency_map[name].predecessors[0].name
+                if isinstance(named_modules[pred], RangeLinearQuantWrapper):
+                    return nn.Identity()
+            norm_name = distiller.utils.normalize_module_name(name)
+            clip_acts = verify_clip_mode(clip_acts)
+            if distiller.has_children(module):
+                return module
+            if fp16:
+                warnings.warn("Argument 'fp16' is deprecated. Please use the 'fpq_module' argument (16/32/64) instead.",
+                              DeprecationWarning)
+                fpq_module = 16
+            if not fake:
+                return FPWrapper(module, fpq_module)
+            return RangeLinearFakeQuantWrapper(module, qbits_map[name].acts, mode=mode, clip_acts=clip_acts,
+                                               activation_stats=self.model_activation_stats.get(norm_name, None),
+                                               clip_n_stds=clip_n_stds, clip_half_range=clip_half_range,
+                                               scale_approx_mult_bits=scale_approx_mult_bits,
+                                               fpq_module=fpq_module)
+
         self.clip_acts = clip_acts
         self.clip_n_stds = clip_n_stds
         self.model_activation_stats = model_activation_stats or {}
@@ -1040,6 +1209,8 @@ class PostTrainLinearQuantizer(Quantizer):
         self.replacement_factory[distiller.modules.BatchMatmul] = factory_matmul
         self.replacement_factory[nn.Embedding] = replace_embedding
 
+        self.default_repalcement_fn = replace_fake_quant
+
         save_dir = msglogger.logdir if hasattr(msglogger, 'logdir') else '.'
         self.save_per_layer_parameters(save_dir)
 
@@ -1067,7 +1238,8 @@ class PostTrainLinearQuantizer(Quantizer):
                        model_activation_stats=args.qe_stats_file,
                        clip_n_stds=args.qe_clip_n_stds,
                        scale_approx_mult_bits=args.qe_scale_approx_bits,
-                       overrides=overrides)
+                       overrides=overrides,
+                       inputs_quant_auto_fallback=True)
 
     def save_per_layer_parameters(self, save_dir=''):
         defaults = OrderedDict(self.model.quantizer_metadata['params'])
@@ -1110,6 +1282,7 @@ class PostTrainLinearQuantizer(Quantizer):
         if not self.has_bidi_distiller_lstm:
             self._apply_bn_folding(dummy_input)
             self._apply_activation_stats_fusions()
+            self._apply_fuse_relu()
         else:
             self._apply_bidi_distiller_lstm_stats_fusion()
 
@@ -1117,6 +1290,14 @@ class PostTrainLinearQuantizer(Quantizer):
         save_path = os.path.join(save_dir, 'quant_stats_after_prepare_model.yaml')
         distiller.yaml_ordered_save(save_path, self.model_activation_stats)
         msglogger.info('Updated stats saved to ' + save_path)
+        # for module_name, override in self.module_overrides_map.items():
+        #     # Hack to bypass Quantizer pre-override check -
+        #     # Quantizer class checks `qbit.acts` and `qbit.wts` before applying overrides
+        #     # but since fp16 doesn't act as an intN - we need to override these
+        #     # tensors to bypass the check
+        #     if (override.get('fp16', False) or override.get('fpq_module', False)) and \
+        #          not override.get('fake', False):
+        #         self.module_qbits_map[module_name] = QBits('fp', None, None)
 
     def _clip_stats(self, entry, min_val, max_val):
         if entry['max'] < min_val:
@@ -1131,7 +1312,7 @@ class PostTrainLinearQuantizer(Quantizer):
 
     def _apply_bn_folding(self, dummy_input):
         msglogger.info('Applying batch-norm folding ahead of post-training quantization')
-        mt.fold_batch_norms_inference(self.model, adjacency_map=self.adjacency_map)
+        mt.fold_batch_norms(self.model, adjacency_map=self.adjacency_map, inference=True)
 
         # After BN folding model need to re-generate the adjacency map
         summary_graph = distiller.SummaryGraph(self.model, dummy_input)
@@ -1145,16 +1326,20 @@ class PostTrainLinearQuantizer(Quantizer):
         named_modules = OrderedDict(self.model.named_modules())
         model_stats = self.model_activation_stats
         for n, m in named_modules.items():
-            try:
-                # Look for the mark left by distiller.model_transforms.fold_batch_norms
-                folded_bn_module = distiller.normalize_module_name(m.fused_modules[0])
-                # Propagate the output stats of the folded BN module to this module
-                # If stats were collected after folding was applied, then stats for the BN module won't exist,
-                # in which case we just move on
-                model_stats[distiller.normalize_module_name(n)]['output'] = model_stats.pop(folded_bn_module)['output']
-                msglogger.debug('  {} --> {}'.format(folded_bn_module, n))
-            except (AttributeError, KeyError):
+            # Look for the mark left by distiller.model_transforms.fold_batch_norms
+            fused_modules = getattr(m, 'fused_modules', None)
+            if fused_modules is None:
                 continue
+            folded_bn_module = distiller.normalize_module_name(fused_modules[0])
+
+            # Propagate the output stats of the folded BN module to this module
+            # If stats were collected after folding was applied, then stats for the BN module won't exist,
+            # in which case we just move on
+            folded_bn_stats = model_stats.pop(folded_bn_module, None)
+            if folded_bn_stats is None:
+                continue
+            model_stats[distiller.normalize_module_name(n)]['output'] = folded_bn_stats['output']
+            msglogger.debug('  {} --> {}'.format(folded_bn_module, n))
 
     def _apply_activation_stats_fusions(self):
         # Now we look for certain "fusions" of layers and activations
@@ -1169,7 +1354,8 @@ class PostTrainLinearQuantizer(Quantizer):
         named_modules = OrderedDict(self.model.named_modules())
         model_stats = self.model_activation_stats
         for n, m in named_modules.items():
-            if distiller.has_children(m) or n not in self.adjacency_map or len(self.adjacency_map[n].successors) != 1:
+            if (distiller.has_children(m) and not isinstance(m, SimulatedFoldedBatchNorm))\
+                    or n not in self.adjacency_map or len(self.adjacency_map[n].successors) != 1:
                 continue
             successor = self.adjacency_map[n].successors[0]
             n = distiller.normalize_module_name(n)
@@ -1192,23 +1378,53 @@ class PostTrainLinearQuantizer(Quantizer):
                     succ_type = 'Sigmoid'
                 succ_stats = None
 
+            # Set the clipping values
             if succ_type == 'Relu':
                 # ReLU zeros out all negative values, so there's no need to quantize them
-                msglogger.debug('  Module {} followed by Relu, updating stats'.format(n))
-                if succ_stats is not None:
-                    m_stats['output'] = deepcopy(succ_stats['output'])
-                    succ_stats['inputs'][0] = deepcopy(succ_stats['output'])
-                else:
-                    msglogger.debug("    Relu op not a module or post-split, can't update mean and std".format(n))
-                    self._clip_stats(m_stats['output'], 0., m_stats['output']['max'])
+                min_val = 0.
+                max_val = m_stats['output']['max']
             elif succ_type == 'Sigmoid' or succ_type == 'Tanh':
                 # Tanh / Sigmoid saturate at ~4 / ~6 respectively. No need to quantize their inputs outside
                 # of these ranges
-                msglogger.debug('  Module {} followed by {}, updating stats'.format(n, succ_type))
-                sat_val = 4. if succ_type == 'Tanh' else 6.
-                self._clip_stats(m_stats['output'], -sat_val, sat_val)
-                if succ_stats is not None:
-                    succ_stats['inputs'][0] = deepcopy(m_stats['output'])
+                max_val = 4. if succ_type == 'Tanh' else 6.
+                min_val = -max_val
+            elif isinstance(named_modules.get(successor.name, None), nn.ReLU6):
+                succ_type = 'ReLU6'
+                # ReLU6 zeros out negative values and saturates at 6, so clip the stats to [0, min(max, 6)]
+                min_val = 0.
+                max_val = min(m_stats['output']['max'], 6)
+            else:
+                continue
+
+            # Clip the stats
+            msglogger.debug('  Module {} followed by {}, updating stats'.format(n, succ_type))
+            self._clip_stats(m_stats['output'], min_val, max_val)
+            if succ_stats is not None:
+                succ_stats['inputs'][0] = deepcopy(m_stats['output'])
+
+    def _apply_fuse_relu(self):
+        """Fuses ReLU layers to the linear layers before them."""
+        model_overrides = self.module_overrides_map
+        qbits_map = self.module_qbits_map
+        named_modules = dict(self.model.named_modules())
+        for n, m in named_modules.items():
+            # Don't fuse if the module isn't quantized:
+            qbits = qbits_map.get(n, QBits(None, None, None))
+            if qbits.acts is None and qbits.wts is None:
+                continue
+            if (distiller.has_children(m) and not isinstance(m, SimulatedFoldedBatchNorm))\
+                    or n not in self.adjacency_map or len(self.adjacency_map[n].successors) != 1:
+                continue
+            successor = self.adjacency_map[n].successors[0]
+            successor_module = named_modules.get(successor.name, None)
+            # Add half range clipping to module overrides
+            m_override = model_overrides.get(n, OrderedDict())
+            model_overrides[n] = m_override
+            if successor.name in named_modules and isinstance(successor_module, (nn.ReLU, nn.ReLU6)):
+                m_override['clip_half_range'] = True
+                m_override = model_overrides.get(successor.name, OrderedDict())
+                m_override['make_identity'] = True
+                model_overrides[successor.name] = m_override
 
     def _apply_bidi_distiller_lstm_stats_fusion(self):
         distiller_lstm_cells = [n for n, m in self.model.named_modules() if
@@ -1220,6 +1436,18 @@ class PostTrainLinearQuantizer(Quantizer):
             sat_val = 6.
             self._clip_stats(self.model_activation_stats[name]['output'], -sat_val, sat_val)
 
+    def _post_prepare_model(self):
+        if isinstance(self.model, nn.DataParallel):
+            # Restore the parameters and buffers of the wrapped module to the master GPU:
+            device = self.model.src_device_obj
+            m = self.model.module
+            for param in m.parameters():
+                param.data = param.data.to(device)
+            for buffer in m.buffers():
+                buffer.data = buffer.data.to(device)
+
 
 ###############################################################################
 # Quantization-aware training
diff --git a/distiller/summary_graph.py b/distiller/summary_graph.py
index 5df0c6b6634f3ff3994d0be4d4666b8b4bb71f28..5c951f9b79d0763005de80c91406b87de77ff6ab 100755
--- a/distiller/summary_graph.py
+++ b/distiller/summary_graph.py
@@ -72,6 +72,10 @@ class SummaryGraph(object):
 
     def __init__(self, model, dummy_input, apply_scope_name_workarounds=True):
         self._src_model = model
+        self._named_modules = OrderedDict(model.named_modules())
+        self._adj_map = None
+        self._layers_topological_order = None
+        self._top_level_ops = set()
         model_clone = distiller.make_non_parallel_copy(model)
 
         # Switch all instances of torch.nn.ModuleList in the model to our DistillerModuleList
@@ -82,6 +86,7 @@ class SummaryGraph(object):
             
             device = distiller.model_device(model_clone)
             dummy_input = distiller.convert_tensors_recursively_to(dummy_input, device=device)
+            self.dummy_input = dummy_input
             trace, _ = jit.get_trace_graph(model_clone, dummy_input, _force_outplace=True)
 
             # As of PyTorch 1.1.0, ONNX trace optimization has two issues that result in incorrect scope names
@@ -185,7 +190,9 @@ class SummaryGraph(object):
                 # so we "unroll" them.
                 same_module_cnt = len(self.module_ops_map[module_name])
                 if same_module_cnt:
-                    new_op['name'] += "__" + str(same_module_cnt)
+                    # TODO: Was this meant to be applied only to 'top_level_ops'? Also, it's not
+                    #       applied to the first module that had the same name
+                    new_op['name'] += "_%s_%d" % (new_op['type'], same_module_cnt)
                 self.module_ops_map[module_name].append(new_op['name'])
 
                 # Finally we register the new op in the ops collection
@@ -506,6 +513,13 @@ class SummaryGraph(object):
                 self._src_model, normalized_layer_name)
             yield sgraph_layer_name, param_name, param
 
+    def _dedicated_module_check(self, n, dedicated_modules_only=False):
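+        # When 'dedicated_modules_only' is requested, accept only ops whose module has no child modules
+        # and executes no other op in the graph.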
+        if not dedicated_modules_only:
+            return True
+        module_name = self.ops[n]['module-name']
+        module = self._named_modules[module_name]
+        return len(self.module_ops_map[module_name]) == 1 and not distiller.has_children(module)
+
     def adjacency_map(self, dedicated_modules_only=False):
         """Returns a mapping from each op in the graph to its immediate predecessors and successors.
 
@@ -519,35 +533,107 @@ class SummaryGraph(object):
               associated with a dedicated module within the underlying model. Examples of this will be
               functional calls, such as "F.relu()", and tensor operations, such as "t3 = t1 + t2".
         """
+        if self._adj_map and not dedicated_modules_only:
+            return self._adj_map
         adj_map = OrderedDict()
-        named_modules = OrderedDict(self._src_model.named_modules())
 
         for op_name, op in self.ops.items():
-            def dedicated_module_check(n):
-                if not dedicated_modules_only:
-                    return True
-                module_name = self.ops[n]['module-name']
-                module = named_modules[module_name]
-                return len(self.module_ops_map[module_name]) == 1 and not distiller.has_children(module)
 
             def op_meta(n):
                 return OpSimpleMetadata(distiller.denormalize_module_name(self._src_model, n), self.ops[n]['type'])
 
-            if not dedicated_module_check(op_name):
+            if not self._dedicated_module_check(op_name, dedicated_modules_only):
                 continue
 
             entry = AdjacentsEntry(op_meta(op_name))
             # Find the immediate preceding and succeeding modules. Depth of 1 gets us the
             # input and output tensors, depth of 2 gets the actual modules
             entry.predecessors = [op_meta(n) for n in self.predecessors(op, 2, denorm_names=False)
-                                  if dedicated_module_check(n)]
+                                  if self._dedicated_module_check(n, dedicated_modules_only)]
             entry.successors = [op_meta(n) for n in self.successors(op, 2, denorm_names=False)
-                                if dedicated_module_check(n)]
+                                if self._dedicated_module_check(n, dedicated_modules_only)]
 
             adj_map[entry.op_meta.name] = entry
-
+        # Cache only the full (unfiltered) map; the filtered variant is rebuilt on demand
+        if not dedicated_modules_only:
+            self._adj_map = adj_map
         return adj_map
 
+    def layers_topological_order(self, recurrent=False):
+        """
+        Prepares an ordered list of the layers to quantize sequentially, sorted by their topological
+        order in the graph.
+        Args:
+            recurrent (bool): whether the model may contain recurrent connections.
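+        Returns:
+            A list of module names (as returned by model.named_modules()) sorted by their topological
+            rank, e.g. ['conv1', 'bn1', 'layer1.0.conv1', ...] for a ResNet-style model (illustrative).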
+        """
+        if self._layers_topological_order:
+            return self._layers_topological_order
+        adj_map = self.adjacency_map()
+        ranked_ops = OrderedDict([(k, _OpRank(v, 0)) for k, v in adj_map.items()])
+
+        def _recurrent_ancestor(ranked_ops_dict, dest_op_name, src_op_name):
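+            # 'dest_op_name' is a recurrent ancestor of 'src_op_name' if 'src_op_name' is reachable from it
+            # (a back-edge, since at the call site 'dest_op_name' is also a successor of 'src_op_name') and
+            # it already holds a positive rank lower than 'src_op_name's rank.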
+            def _is_descendant(parent_op_name, dest_op_name):
+                successors_names = [op.name for op in adj_map[parent_op_name].successors]
+                if dest_op_name in successors_names:
+                    return True
+                for succ_name in successors_names:
+                    if _is_descendant(succ_name, dest_op_name):
+                        return True
+                return False
+
+            return _is_descendant(dest_op_name, src_op_name) and \
+                   (0 < ranked_ops_dict[dest_op_name].rank < ranked_ops_dict[src_op_name].rank)
+
+        def rank_op(ranked_ops_dict, op_name, rank):
+            ranked_ops_dict[op_name].rank = rank
+            for child_op in adj_map[op_name].successors:
+                # In recurrent models: if a successor is also an ancestor - we don't increment its rank.
+                if not recurrent or not _recurrent_ancestor(ranked_ops_dict, child_op.name, op_name):
+                    rank_op(ranked_ops_dict, child_op.name, ranked_ops_dict[op_name].rank + 1)
+
+        roots = [k for k, v in adj_map.items() if len(v.predecessors) == 0]
+        for root_op_name in roots:
+            rank_op(ranked_ops, root_op_name, 0)
+
+        # Take only the modules from the original model
+        module_dict = dict(self._src_model.named_modules())
+        ret = sorted([k for k in ranked_ops.keys() if k in module_dict],
+                     key=lambda k: ranked_ops[k].rank)
+        # Check that only the actual roots have a rank of 0
+        assert {k for k in ret if ranked_ops[k].rank == 0} <= set(roots)
+        self._layers_topological_order = ret
+        return ret
+
+    def top_level_ops(self):
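+        """Returns the set of ops that have no predecessors, i.e. the entry points of the traced graph."""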
+        if self._top_level_ops:
+            return self._top_level_ops
+        for op_name in self.ops:
+            if not self.predecessors(op_name, 1):
+                self._top_level_ops.add(op_name)
+        return self._top_level_ops
+
+    def missing_modules(self):
+        """
+        Returns a list of ops in the graph that don't have a dedicated module in the model
+        (e.g. functional calls and direct tensor operations).
+        """
+        return [op_name for op_name in self.adjacency_map()
+                if not self._dedicated_module_check(op_name, True)]
+
+
+class _OpRank:
+    def __init__(self, adj_entry, rank=None):
+        self.adj_entry = adj_entry
+        self._rank = rank or 0
+
+    @property
+    def rank(self):
+        return self._rank
+
+    @rank.setter
+    def rank(self, val):
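+        # A rank can only grow - assigning a value lower than the current rank is a no-op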
+        self._rank = max(val, self._rank)
+
+    def __repr__(self):
+        return '_OpRank(\'%s\' | %d)' % (self.adj_entry.op_meta.name, self.rank)
+
 
 class OpSimpleMetadata(object):
     def __init__(self, name, type):
diff --git a/distiller/utils.py b/distiller/utils.py
index d1c1feba9b825eab9ba81ef2b229b0da925d8d65..3b342b57a5d3181436986984e2e401e49fba14fc 100755
--- a/distiller/utils.py
+++ b/distiller/utils.py
@@ -739,3 +739,35 @@ def convert_tensors_recursively_to(val, *args, **kwargs):
 
     return val
 
+
+# TODO: Is this needed?
+def model_setattr(model, attr_name, val, register=False):
+    """
+    Sets an attribute of a model anywhere in its module hierarchy.
+    Args:
+        model (nn.Module): the model.
+        attr_name (str): the attribute name as shown by model.named_<parameters/modules/buffers>()
+        val: the value of the attribute
+        register (bool): if True and val is a tensor, register it on the owning module -
+          register_parameter() if it's an nn.Parameter, register_buffer() otherwise.
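+    Example (illustrative - 'block1.conv.my_buffer' is a hypothetical attribute path):
+        model_setattr(model, 'block1.conv.my_buffer', torch.zeros(1), register=True)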
+    """
+    def split_name(name):
+        if '.' in name:
+            return name.rsplit('.', 1)
+        else:
+            return '', name
+    modules_dict = OrderedDict(model.named_modules())
+    lowest_depth_container_name, lowest_depth_attr_name = split_name(attr_name)
+    while lowest_depth_container_name and lowest_depth_container_name not in modules_dict:
+        container_name, attr = split_name(lowest_depth_container_name)
+        lowest_depth_container_name = container_name
+        lowest_depth_attr_name = '%s%s' % (attr, lowest_depth_attr_name)
+    lowest_depth_container = modules_dict[lowest_depth_container_name]  # type: nn.Module
+
+    if register and torch.is_tensor(val):
+        if isinstance(val, nn.Parameter):
+            lowest_depth_container.register_parameter(lowest_depth_attr_name, val)
+        else:
+            lowest_depth_container.register_buffer(lowest_depth_attr_name, val)
+    else:
+        setattr(lowest_depth_container, lowest_depth_attr_name, val)
diff --git a/examples/quantization/post_train_quant/resnet18_imagenet_post_train_input_overrides.yaml b/examples/quantization/post_train_quant/resnet18_imagenet_post_train_input_overrides.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7427dcdd625c8efc42a9ddae5b005096dd8b31f1
--- /dev/null
+++ b/examples/quantization/post_train_quant/resnet18_imagenet_post_train_input_overrides.yaml
@@ -0,0 +1,75 @@
+# Specifying quantization settings for layer INPUTS
+# -------------------------------------------------
+#
+# This configuration file demonstrates how to override the default behavior of PostTrainLinearQuantizer
+# with regards to quantizing inputs of layers.
+#
+# THIS APPLIES ONLY TO PostTrainLinearQuantizer
+#
+# 1. By default, settings for quantizing activations are applied to OUTPUTs of layers
+#
+# 2. When the output of a layer is quantized, the quantization settings (aka metadata) are attached to
+#    the tensor.
+# 
+# 3. Certain layers expect quantized inputs. That is - they expect input tensors that have quantization
+#    metadata attached.
+# 
+# 4. But, there are cases where this metadata will not be available, such as:
+#     4.1 "Root" inputs of the model (which are not the output of any layer)
+#     4.2 When two quantized layers are separated by one or more non-quantized operations. 
+#         This can happen if the user explicitly requested that these operations not be quantized, or if
+#         they are operations that can't be detected by the quantizer.
+# 
+# 5. In these cases, the default behavior of the quantizer is to quantize the layer inputs according
+#    to the settings that were provided for the layer's outputs (see item 1).
+#    This behavior is controlled by the 'inputs_quant_auto_fallback' parameter. By default it is set
+#    to True.
+#
+# 6. One can override this default behavior and explicitly set quantization settings for inputs. Below we show a
+#    couple of examples for this.
+#    Note that when the input tensor does have quantization metadata attached, specifying explicit
+#    settings isn't allowed, and an error will be raised. This is done in order to maintain maximum
+#    consistency with quantization settings flowing through the model.
+
+quantizers:
+  post_train_quantizer:
+    class: PostTrainLinearQuantizer
+    bits_activations: 8
+    bits_parameters: 8
+    bits_accum: 32
+    mode: ASYMMETRIC_UNSIGNED
+    # Path to stats file assuming this is being invoked from the 'classifier_compression' example directory
+    model_activation_stats: ../quantization/post_train_quant/stats/resnet18_quant_stats.yaml
+    per_channel_wts: True
+    clip_acts: AVG
+
+    # If this is set to True, then when a quantized input is required but the tensor doesn't carry quantization
+    # metadata, the input is quantized according to the module's output settings. Here we disable it and instead
+    # set the inputs explicitly via 'input_overrides' below.
+    inputs_quant_auto_fallback: False
+
+    overrides:
+      conv1:
+        # The input to the first layer in the model will never have quantization metadata (item 4.1 above)
+        # Since layers could have multiple inputs, we specify the index of the required input as a key
+        input_overrides:
+          0:
+            # Shorthand to take the quantization settings of the output (ignores any other settings)
+            from_outputs: True
+
+      fc:
+        # Here we show an example of mixing output overrides and input overrides
+        
+        # We don't clip the output of the last layer
+        clip_acts: None
+
+        # In ResNet, the FC layer is preceded by a view op, which drops the quantization metadata (item 4.2 above),
+        # so we override the input settings.
+        # But here we don't want to take the settings from the output, where we just set clip_acts to None.
+        # So we specify clip_acts explicitly.
+        input_overrides:
+          0:
+            # Example of setting the actual value. Applicable only if 'from_outputs' isn't set.
+            # The following keys are supported: 'bits_activations', 'mode', 'clip_acts', 'clip_n_stds'
+            # Any key not explicitly set will default to the output setting
+            clip_acts: AVG
+
diff --git a/tests/full_flow_tests.py b/tests/full_flow_tests.py
index ef466198d1162a63fe8051224fbd92c9808be05e..d8617b7f78b263690b058714cd27c251a61d47e1 100755
--- a/tests/full_flow_tests.py
+++ b/tests/full_flow_tests.py
@@ -123,7 +123,7 @@ test_configs = [
     TestConfig('--arch simplenet_cifar --epochs 2', DS_CIFAR, accuracy_checker, [44.460, 91.230]),
     TestConfig('-a resnet20_cifar --resume {0} --quantize-eval --evaluate --qe-clip-acts avg --qe-no-clip-layers {1}'.
                format(os.path.join(examples_root, 'ssl', 'checkpoints', 'checkpoint_trained_dense.pth.tar'), 'fc'),
-               DS_CIFAR, accuracy_checker, [91.64, 99.63]),
+               DS_CIFAR, accuracy_checker, [91.57, 99.62]),
     TestConfig('-a preact_resnet20_cifar --epochs 2 --compress {0}'.
                format(os.path.join('full_flow_tests', 'preact_resnet20_cifar_pact_test.yaml')),
                DS_CIFAR, accuracy_checker, [44.370, 89.640]),
diff --git a/tests/test_model_transforms.py b/tests/test_model_transforms.py
index 69a4fb0979392f3dbda64d8eb02e2ff38532ace2..9ac33534ef379b67ad9dffd77e1a7d9ebfeb0398 100644
--- a/tests/test_model_transforms.py
+++ b/tests/test_model_transforms.py
@@ -225,7 +225,7 @@ def test_fuse_modules_with_pre_exist_adj_map():
 )
 def test_fold_batch_norms_inference_no_fold(model, input_shape):
     orig_model = deepcopy(model)
-    folded_model = mt.fold_batch_norms_inference(model, dummy_input=torch.randn(input_shape))
+    folded_model = mt.fold_batch_norms(model, dummy_input=torch.randn(input_shape), inference=True)
     for (n_orig, m_orig), (n_folded, m_folded) in zip(orig_model.named_modules(), folded_model.named_modules()):
         assert n_folded == n_orig
         assert type(m_folded) == type(m_orig)
@@ -255,7 +255,7 @@ def test_fold_batch_norms_inference(model, input_shape):
     model.eval()
     orig_model = deepcopy(model)
     dummy_input = torch.randn(input_shape)
-    folded_model = mt.fold_batch_norms_inference(model, dummy_input=dummy_input)
+    folded_model = mt.fold_batch_norms(model, dummy_input=dummy_input, inference=True)
     assert type(folded_model.seq[0]) == type(orig_model.seq[0])
     assert type(folded_model.seq[1]) == nn.Identity
 
diff --git a/tests/test_post_train_quant.py b/tests/test_post_train_quant.py
index 8bdc3fb69034ee53b0c531138561762db9912bb5..685c956cf660ad71898d076e222ad5df878d5602 100644
--- a/tests/test_post_train_quant.py
+++ b/tests/test_post_train_quant.py
@@ -25,11 +25,28 @@ from copy import deepcopy
 from distiller.quantization import RangeLinearQuantParamLayerWrapper, LinearQuantMode, ClipMode, \
     RangeLinearQuantConcatWrapper, RangeLinearQuantEltwiseMultWrapper, RangeLinearQuantEltwiseAddWrapper, \
     PostTrainLinearQuantizer
+from distiller.quantization.range_linear import _get_quant_params_from_tensor, _get_quant_params_from_stats_dict,\
+    TensorQuantMetadata
+from distiller.quantization import q_utils
 from distiller.data_loggers import QuantCalibrationStatsCollector, collector_context
 import distiller.modules
 from common import WrappedSequential
 
 
+def attach_quant_metadata(t, num_bits, quant_mode, stats=None, clip_mode=ClipMode.NONE, per_channel=False,
+                          num_stds=None, scale_approx_mult_bits=None):
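+    # Compute scale / zero-point from the tensor itself or from pre-collected stats, and attach them,
+    # together with the quantized range, to the tensor as TensorQuantMetadata - mimicking the metadata
+    # that the quantized wrappers expect to find on their inputs.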
+    if stats is None:
+        scale, zp = _get_quant_params_from_tensor(t, num_bits, quant_mode, clip_mode, per_channel, num_stds,
+                                                  scale_approx_mult_bits)
+    else:
+        scale, zp = _get_quant_params_from_stats_dict(stats, num_bits, quant_mode, clip_mode, num_stds,
+                                                      scale_approx_mult_bits)
+    signed = quant_mode != LinearQuantMode.ASYMMETRIC_UNSIGNED
+    min_q_val, max_q_val = q_utils.get_quantized_range(num_bits, signed)
+    t.quant_metadata = TensorQuantMetadata(scale, zp, min_q_val, max_q_val)
+    return t
+
+
 ###############################################################################
 # Test Convolution
 ###############################################################################
@@ -121,6 +138,10 @@ def test_conv_layer_wrapper(conv_input, conv_weights, mode, clip_acts, per_chann
     model = RangeLinearQuantParamLayerWrapper(layer, 8, 8, mode=mode, clip_acts=clip_acts,
                                               per_channel_wts=per_channel_wts, activation_stats=conv_stats)
 
+    input_stats = None if conv_stats is None else conv_stats['inputs'][0]
+    conv_input = attach_quant_metadata(conv_input, 8, mode, stats=input_stats, clip_mode=clip_acts,
+                                       per_channel=False, num_stds=None, scale_approx_mult_bits=None)
+
     with pytest.raises(RuntimeError):
         model(conv_input)
 
@@ -174,6 +195,9 @@ def test_linear_layer_wrapper(linear_input, linear_weights, linear_bias,
     model = RangeLinearQuantParamLayerWrapper(layer, 8, 8, mode=mode, clip_acts=clip_acts,
                                               per_channel_wts=per_channel_wts)
 
+    linear_input = attach_quant_metadata(linear_input, 8, mode, stats=None, clip_mode=clip_acts,
+                                         per_channel=False, num_stds=None, scale_approx_mult_bits=None)
+
     with pytest.raises(RuntimeError):
         model(linear_input)
 
@@ -196,7 +220,7 @@ def inputs():
     in_1_b_0 = torch.tensor([[[[-3, 6], [0, 8]], [[4, 10], [-7, 1]]]], dtype=torch.float32)
     in_1_b_1 = torch.tensor([[[[-100, 50], [6, 12]], [[80, -30], [-16, 3]]]], dtype=torch.float32)
     in_1 = torch.cat((in_1_b_0, in_1_b_1), 0)
-    return in_0, in_1
+    return [in_0, in_1]
 
 
 input_stats = OrderedDict()
@@ -264,6 +288,11 @@ def test_concat_layer_wrapper(inputs, concat_stats, mode, clip_acts, expected_ou
         # Check exception on no stats
         RangeLinearQuantConcatWrapper(layer, 8, mode, clip_acts, activation_stats=None)
 
+    for idx in range(len(inputs)):
+        inputs[idx] = attach_quant_metadata(inputs[idx], 8, mode, stats=concat_stats['inputs'][idx],
+                                            clip_mode=clip_acts, per_channel=False, num_stds=None,
+                                            scale_approx_mult_bits=None)
+
     model = RangeLinearQuantConcatWrapper(layer, 8, mode, clip_acts, concat_stats)
     model.eval()
     output = model(*inputs)
@@ -319,6 +348,11 @@ def test_eltwise_mult_layer_wrapper(inputs, eltwise_mult_stats, mode, clip_acts,
         # Check exception on no stats
         RangeLinearQuantEltwiseMultWrapper(layer, 8, mode, clip_acts, activation_stats=None)
 
+    for idx in range(len(inputs)):
+        inputs[idx] = attach_quant_metadata(inputs[idx], 8, mode, stats=eltwise_mult_stats['inputs'][idx],
+                                            clip_mode=clip_acts, per_channel=False, num_stds=None,
+                                            scale_approx_mult_bits=None)
+
     model = RangeLinearQuantEltwiseMultWrapper(layer, 8, mode, clip_acts, eltwise_mult_stats)
     model.eval()
     output = model(*inputs)
@@ -374,6 +408,11 @@ def test_eltwise_add_layer_wrapper(inputs, eltwise_add_stats, mode, clip_acts, e
         # Check exception on no stats
         RangeLinearQuantEltwiseAddWrapper(layer, 8, mode, clip_acts, activation_stats=None)
 
+    for idx in range(len(inputs)):
+        inputs[idx] = attach_quant_metadata(inputs[idx], 8, mode, stats=eltwise_add_stats['inputs'][idx],
+                                            clip_mode=clip_acts, per_channel=False, num_stds=None,
+                                            scale_approx_mult_bits=None)
+
     model = RangeLinearQuantEltwiseAddWrapper(layer, 8, mode, clip_acts, eltwise_add_stats)
     model.eval()
     output = model(*inputs)
@@ -439,8 +478,8 @@ def test_override_no_clip(overrides, e_clip_acts, e_n_stds, rnn_model, rnn_model
                                          model_activation_stats=rnn_model_stats)
     quantizer.prepare_model(torch.randn(1, 1, 20))
     assert isinstance(quantizer.model.rnn.cells[0].eltwisemult_hidden, RangeLinearQuantEltwiseMultWrapper)
-    assert quantizer.model.rnn.cells[0].eltwisemult_hidden.clip_acts == e_clip_acts
-    assert quantizer.model.rnn.cells[0].eltwisemult_hidden.clip_n_stds == e_n_stds
+    assert quantizer.model.rnn.cells[0].eltwisemult_hidden.output_quant_settings.clip_mode == e_clip_acts
+    assert quantizer.model.rnn.cells[0].eltwisemult_hidden.output_quant_settings.clip_n_stds == e_n_stds
 
 
 ###############################################################################
@@ -551,7 +590,7 @@ def test_stats_fusion_just_bn():
 @pytest.mark.parametrize(
     'act_type, act_as_module, bn_out_stats, conv_out_expected_stats',
     [
-        ('relu', True, stats_entry(-5., 5., -3., 3., 0., 0.5), None),
+        ('relu', True, stats_entry(-5., 5., -3., 3., 0., 0.5), stats_entry(0., 5., 0, 3., 0., 0.5)),
         ('relu', False, stats_entry(-5., 5., -3., 3., 0., 0.5), stats_entry(0., 5., 0, 3., 0., 0.5)),
         ('relu', False, stats_entry(1., 5., 2., 3., 2.5, 0.5), stats_entry(1., 5., 2., 3., 2.5, 0.5)),
         ('relu', False, stats_entry(-5., -1., -4., -2., -2.5, 0.5), stats_entry(0., 0, 0, 0., -2.5, 0.5)),
@@ -577,16 +616,9 @@ def test_stats_fusion_sequential(act_type, act_as_module, bn_out_stats, conv_out
 
     expected = deepcopy(stats)
     expected.pop('bn')  # After BN folding BN stats are removed
-    if act_type == 'relu':
-        if act_as_module:
-            expected['conv']['output'] = deepcopy(stats['act']['output'])
-            expected['act']['inputs'][0] = deepcopy(stats['act']['output'])
-        else:
-            expected['conv']['output'] = conv_out_expected_stats
-    else:
-        expected['conv']['output'] = conv_out_expected_stats
-        if act_as_module:
-            expected['act']['inputs'][0] = conv_out_expected_stats
+    expected['conv']['output'] = conv_out_expected_stats
+    if act_as_module:
+        expected['act']['inputs'][0] = conv_out_expected_stats
 
     assert quantizer.model_activation_stats == expected
 
diff --git a/tests/test_quantizer.py b/tests/test_quantizer.py
index 049e563eda6d0c8034a9414af0bef43b58181716..467aef8e3ddf64119dff7d5dd98c664a7845de68 100644
--- a/tests/test_quantizer.py
+++ b/tests/test_quantizer.py
@@ -396,13 +396,6 @@ def test_param_quantization(model, optimizer, qbits, overrides, explicit_expecte
 
 
 def test_overridable_args(model, optimizer, train_with_fp_copy):
-    model_copy = deepcopy(model)
-    conv_override = OrderedDict([(acts_key, None), (wts_key, None), (bias_key, None), ('prop', 123)])
-    overrides = OrderedDict([('conv1', conv_override)])
-    q = DummyQuantizer(model_copy, optimizer=optimizer, overrides=overrides, train_with_fp_copy=train_with_fp_copy)
-    pytest_raises_wrapper(ValueError, 'Expecting ValueError when overriding args without overriding bits',
-                          q.prepare_model)
-
     model_copy = deepcopy(model)
     conv_override = OrderedDict([(acts_key, 8), (wts_key, 8), (bias_key, 32), ('prop', 123), ('unexpetcted_prop', 456)])
     overrides = OrderedDict([('conv1', conv_override)])
diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py
index 293f69bd4049ad41baf7d88a827c977689caed51..98feba80f7aefed82e3bd3e101e5266a44455d91 100755
--- a/tests/test_summarygraph.py
+++ b/tests/test_summarygraph.py
@@ -125,7 +125,7 @@ def test_layer_search(parallel, denorm_names):
     assert preds == prefix_strs(['conv1'], prefix)
 
     preds = g.predecessors_f('layer1.1.conv1', 'Conv', [], logging, denorm_names=denorm_names)
-    assert preds == prefix_strs(['layer1.0.conv2', 'conv1'], prefix)
+    assert preds == prefix_strs(['conv1', 'layer1.0.conv2'], prefix)
 
 
 def test_vgg():
@@ -306,7 +306,7 @@ def test_scope_name_workarounds():
     # may fix them, in which case we can remove the workarounds
     sg = SummaryGraph(m, dummy_input, apply_scope_name_workarounds=False)
     names, types = zip(*[(op_name, op['type']) for op_name, op in sg.ops.items()])
-    assert names == ('drop1', 'drop2', 'drop2__1', 'relu2', 'drop3')
+    assert names == ('drop1', 'drop2', 'drop2_Gemm_1', 'relu2', 'drop3')
     assert types == expected_types
 
     # Now test with the workarounds