From 0b493fd30af03550d6e1e33c9d83c0f3f83b90c5 Mon Sep 17 00:00:00 2001 From: Lev Zlotnik <46742999+levzlotnik@users.noreply.github.com> Date: Sun, 2 Feb 2020 12:22:10 +0200 Subject: [PATCH] Loss Aware Post Train Quantization search (#432) "Loss Aware Post-Training Quantization" (Nahshan et al., 2019) Paper: https://arxiv.org/abs/1911.07190 Reference implementation: https://github.com/ynahshan/nn-quantization-pytorch/tree/master/lapq Proper documentation is still TODO, for now see the example YAML file at 'examples/quantization/post_train_quant/resnet18_imagenet_post_train_lapq.yaml' * Implemented in distiller/quantization/ptq_coordinate_search.py * At the moment that file both the model-independent algorithm implementation and image-classification specific sample script. Still TODO: Refactor that * Post train quantization changes (range_linear): * Added getters/setters for quantization parameters (scale/zero_point) and clipping values * Add option to save backup of FP32 weights to allow re-quantization after quantizer was created. * Add option to clip weights in addition to activations * Fix fusions to not occur only when activations aren't quantized * RangeLinearFakeQuantWrapper: * Make inputs quantization optional * In case of ReLU + ACIQ, clip according to input stats * Data loaders: * Add option to not load train set at all from disk (to speed up loading time in post-training runs) * Modified "image_classifier.py" accordingly --- distiller/apputils/data_loaders.py | 134 +++-- distiller/apputils/image_classifier.py | 13 +- .../quantization/ptq_coordinate_search.py | 515 ++++++++++++++++++ distiller/quantization/quantizer.py | 5 + distiller/quantization/range_linear.py | 500 ++++++++++++++--- distiller/utils.py | 10 +- .../resnet18_imagenet_post_train_lapq.yaml | 101 ++++ tests/test_post_train_quant.py | 55 +- 8 files changed, 1171 insertions(+), 162 deletions(-) create mode 100644 distiller/quantization/ptq_coordinate_search.py create mode 100644 examples/quantization/post_train_quant/resnet18_imagenet_post_train_lapq.yaml diff --git a/distiller/apputils/data_loaders.py b/distiller/apputils/data_loaders.py index 2394aa9..ae18998 100755 --- a/distiller/apputils/data_loaders.py +++ b/distiller/apputils/data_loaders.py @@ -66,7 +66,7 @@ def __dataset_factory(dataset): def load_data(dataset, data_dir, batch_size, workers, validation_split=0.1, deterministic=False, effective_train_size=1., effective_valid_size=1., effective_test_size=1., - fixed_subset=False, sequential=False): + fixed_subset=False, sequential=False, test_only=False): """Load a dataset. 
Args: @@ -94,29 +94,34 @@ def load_data(dataset, data_dir, batch_size, workers, validation_split=0.1, dete effective_valid_size=effective_valid_size, effective_test_size=effective_test_size, fixed_subset=fixed_subset, - sequential=sequential) + sequential=sequential, + test_only=test_only) -def mnist_get_datasets(data_dir): +def mnist_get_datasets(data_dir, load_train=True, load_test=True): """Load the MNIST dataset.""" - train_transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) - ]) - train_dataset = datasets.MNIST(root=data_dir, train=True, - download=True, transform=train_transform) - - test_transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) - ]) - test_dataset = datasets.MNIST(root=data_dir, train=False, - transform=test_transform) + train_dataset = None + if load_train: + train_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ]) + train_dataset = datasets.MNIST(root=data_dir, train=True, + download=True, transform=train_transform) + + test_dataset = None + if load_test: + test_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ]) + test_dataset = datasets.MNIST(root=data_dir, train=False, + transform=test_transform) return train_dataset, test_dataset -def cifar10_get_datasets(data_dir): +def cifar10_get_datasets(data_dir, load_train=True, load_test=True): """Load the CIFAR10 dataset. The original training dataset is split into training and validation sets (code is @@ -134,28 +139,32 @@ def cifar10_get_datasets(data_dir): [1] C.-Y. Lee, S. Xie, P. Gallagher, Z. Zhang, and Z. Tu. Deeply Supervised Nets. arXiv:1409.5185, 2014 """ - train_transform = transforms.Compose([ - transforms.RandomCrop(32, padding=4), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) - ]) - - train_dataset = datasets.CIFAR10(root=data_dir, train=True, - download=True, transform=train_transform) - - test_transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) - ]) - - test_dataset = datasets.CIFAR10(root=data_dir, train=False, - download=True, transform=test_transform) + train_dataset = None + if load_train: + train_transform = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + ]) + + train_dataset = datasets.CIFAR10(root=data_dir, train=True, + download=True, transform=train_transform) + + test_dataset = None + if load_test: + test_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + ]) + + test_dataset = datasets.CIFAR10(root=data_dir, train=False, + download=True, transform=test_transform) return train_dataset, test_dataset -def imagenet_get_datasets(data_dir): +def imagenet_get_datasets(data_dir, load_train=True, load_test=True): """ Load the ImageNet dataset. 
""" @@ -164,23 +173,27 @@ def imagenet_get_datasets(data_dir): normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) - train_transform = transforms.Compose([ - transforms.RandomResizedCrop(224), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - normalize, - ]) + train_dataset = None + if load_train: + train_transform = transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ]) - train_dataset = datasets.ImageFolder(train_dir, train_transform) + train_dataset = datasets.ImageFolder(train_dir, train_transform) - test_transform = transforms.Compose([ - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - normalize, - ]) + test_dataset = None + if load_test: + test_transform = transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ]) - test_dataset = datasets.ImageFolder(test_dir, test_transform) + test_dataset = datasets.ImageFolder(test_dir, test_transform) return train_dataset, test_dataset @@ -263,14 +276,25 @@ def _get_sampler(data_source, effective_size, fixed_subset=False, sequential=Fal def get_data_loaders(datasets_fn, data_dir, batch_size, num_workers, validation_split=0.1, deterministic=False, effective_train_size=1., effective_valid_size=1., effective_test_size=1., fixed_subset=False, - sequential=False): - train_dataset, test_dataset = datasets_fn(data_dir) + sequential=False, test_only=False): + train_dataset, test_dataset = datasets_fn(data_dir, load_train=not test_only, load_test=True) worker_init_fn = None if deterministic: distiller.set_deterministic() worker_init_fn = __deterministic_worker_init_fn + test_indices = list(range(len(test_dataset))) + test_sampler = _get_sampler(test_indices, effective_test_size, fixed_subset, sequential) + test_loader = torch.utils.data.DataLoader(test_dataset, + batch_size=batch_size, sampler=test_sampler, + num_workers=num_workers, pin_memory=True) + + input_shape = __image_size(test_dataset) + + if test_only: + return None, None, test_loader, input_shape + num_train = len(train_dataset) indices = list(range(num_train)) @@ -296,13 +320,5 @@ def get_data_loaders(datasets_fn, data_dir, batch_size, num_workers, validation_ num_workers=num_workers, pin_memory=True, worker_init_fn=worker_init_fn) - test_indices = list(range(len(test_dataset))) - test_sampler = _get_sampler(test_indices, effective_test_size, fixed_subset, sequential) - test_loader = torch.utils.data.DataLoader(test_dataset, - batch_size=batch_size, sampler=test_sampler, - num_workers=num_workers, pin_memory=True) - - input_shape = __image_size(train_dataset) - # If validation split was 0 we use the test set as the validation set return train_loader, valid_loader or test_loader, test_loader, input_shape diff --git a/distiller/apputils/image_classifier.py b/distiller/apputils/image_classifier.py index 172bfeb..3dd6339 100755 --- a/distiller/apputils/image_classifier.py +++ b/distiller/apputils/image_classifier.py @@ -470,13 +470,18 @@ def save_collectors_data(collectors, directory): def load_data(args, fixed_subset=False, sequential=False, load_train=True, load_val=True, load_test=True): - train_loader, val_loader, test_loader, _ = apputils.load_data(args.dataset, + test_only = not load_train and not load_val + + train_loader, val_loader, test_loader, _ = apputils.load_data(args.dataset, os.path.expanduser(args.data), args.batch_size, args.workers, args.validation_split, 
args.deterministic, args.effective_train_size, args.effective_valid_size, args.effective_test_size, - fixed_subset, sequential) - msglogger.info('Dataset sizes:\n\ttraining=%d\n\tvalidation=%d\n\ttest=%d', - len(train_loader.sampler), len(val_loader.sampler), len(test_loader.sampler)) + fixed_subset, sequential, test_only) + if test_only: + msglogger.info('Dataset sizes:\n\ttest=%d', len(test_loader.sampler)) + else: + msglogger.info('Dataset sizes:\n\ttraining=%d\n\tvalidation=%d\n\ttest=%d', + len(train_loader.sampler), len(val_loader.sampler), len(test_loader.sampler)) loaders = (train_loader, val_loader, test_loader) flags = (load_train, load_val, load_test) diff --git a/distiller/quantization/ptq_coordinate_search.py b/distiller/quantization/ptq_coordinate_search.py new file mode 100644 index 0000000..cd2f4df --- /dev/null +++ b/distiller/quantization/ptq_coordinate_search.py @@ -0,0 +1,515 @@ +# +# Copyright (c) 2019 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# +# Implementation of "Loss Aware Post-Training Quantization" (Nahshan et al., 2019) +# +# Paper: https://arxiv.org/abs/1911.07190 +# Reference implementation: https://github.com/ynahshan/nn-quantization-pytorch/tree/master/lapq +# + +import torch +import torch.nn as nn +from distiller.quantization.range_linear import PostTrainLinearQuantizer, ClipMode, \ + RangeLinearQuantWrapper, RangeLinearEmbeddingWrapper, RangeLinearQuantParamLayerWrapper, \ + is_post_train_quant_wrapper, LinearQuantMode +from distiller.quantization import is_linear_quant_mode_asymmetric, is_linear_quant_mode_symmetric +from functools import partial +from distiller.summary_graph import SummaryGraph +from distiller.model_transforms import fold_batch_norms +import distiller.modules +from distiller.data_loggers import collect_quant_stats +from collections import OrderedDict +from itertools import count +import logging +from copy import deepcopy +import distiller.apputils.image_classifier as classifier +import os +import distiller.apputils as apputils +import scipy.optimize as opt +import numpy as np + + +def quant_params_dict2vec(p_dict, search_clipping=False): + """ + Convert quantization params dictionary returned by post-train quantizer to a numpy array that can be used + with scipy.opt.minimize + """ + keys = [] + vals = [] + for k, v in p_dict.items(): + if search_clipping and isinstance(v, tuple): + # When both min and amx values are optimized, we need to concatenate them in the array + # We create dual matching keys so it's easy to convert back to a dict + keys += [k + '_min', k + '_max'] + vals += [v[0].item(), v[1].item()] + else: + keys.append(k) + vals.append(v.item()) + + return keys, np.array(vals) + + +def quant_params_vec2dict(keys, vals, search_clipping=False): + """ + Convert the vector(s) created by quant_params_dict2vec to a dictionary of quantization parameters that + the post-training quantizer API can digest + """ + res = OrderedDict() + for idx, k in enumerate(keys): + if search_clipping and k.endswith('_min'): 
+ res[k[:-4]] = sorted((vals[idx], vals[idx + 1])) + elif search_clipping and k.endswith('_max'): + continue + else: + res[k] = abs(vals[idx]) + return res + + +def lp_loss(x: torch.Tensor, y: torch.Tensor, p): + tmp = (x - y).abs_().pow_(p) + loss = (torch.sum(tmp) / x.numel()).item() + return loss + + +def _check_qp_vec(keys, qp_vec, quant_mode=LinearQuantMode.SYMMETRIC, search_clipping=False): + if is_linear_quant_mode_symmetric(quant_mode): + return all(qp_vec > 0) + if not search_clipping: + idxs_scales = np.array(['scale' in key for key in keys]) + qp_vec_scales = qp_vec[idxs_scales] + return all(qp_vec_scales > 0) + + +l1_loss = partial(lp_loss, p=1) +l2_loss = partial(lp_loss, p=2) +l3_loss = partial(lp_loss, p=3) + + +_INIT_MODES = { + 'NONE': ClipMode.NONE, 'AVG': ClipMode.AVG, 'LAPLACE': ClipMode.LAPLACE, 'GAUSS': ClipMode.GAUSS, + 'L1': l1_loss, 'L2': l2_loss, 'L3': l3_loss +} + + +def _init_mode_from_str(init_mode_str): + if init_mode_str not in _INIT_MODES: + raise ValueError('Unsupported init mode \'%s\'. ' + 'The supported init modes are: %s.' % (init_mode_str, _INIT_MODES)) + return _INIT_MODES[init_mode_str] + + +def optimize_for_layer(layer, quantized_layer, loss_fn, input, method=None, search_clipping=False): + """ + Searches for optimal linear quantization parameters (scale, zero_point) for a layer + with respect to the loss function. Assumes loss_fn is of the signature `loss_fn(y, y_q)->float` + + We perform the initialization a bit differently compared to the paper/reference implementation: + * In the reference: + * Weights and activations are initialized based on quantization loss of their respective tensors. + * Activations are initialized "online", meaning the input to the layer N being initialized is the + output of the already quantized layer N-1. + * In this implementation: + * For a given layer, we initialize both activations and weights together (as applicable) based on the + LP loss between the quantized layer output and the FP32 layer output. + * But, we don't do "online" initialization. That is, each layer is initialized independently from the + quantization parameters obtained for earlier layers. + + Args: + layer (nn.Module): the original, pre-quantized, layer. + quantized_layer (RangeLinearQuantWrapper or RangeLinearEmbeddingWrapper): the post-quantized layer. + loss_fn (callable): the loss function to optimize with respect to it. + method (str or callable): the method of optimization, as will be used by scipy.optimize.minimize. 
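
The two helper functions above are the glue between the quantizer's parameter dictionary and the flat vector that scipy.optimize.minimize operates on, and the Lp losses are the per-layer objectives used for initialization. A minimal sketch of the round trip, assuming the new file is importable as distiller.quantization.ptq_coordinate_search; the layer name and numeric values are illustrative only, not taken from the patch:

import torch
import numpy as np
from collections import OrderedDict
from distiller.quantization.ptq_coordinate_search import (
    quant_params_dict2vec, quant_params_vec2dict, l2_loss)

# A hypothetical parameter dict, shaped like what named_linear_quant_params() yields
qp_dict = OrderedDict([
    ('conv1.output_scale', torch.tensor(0.05)),
    ('conv1.output_zero_point', torch.tensor(3.)),
])

keys, qp_vec = quant_params_dict2vec(qp_dict)   # flat numpy vector for the optimizer
assert isinstance(qp_vec, np.ndarray) and len(qp_vec) == 2

# ... the optimizer perturbs qp_vec ...
new_dict = quant_params_vec2dict(keys, qp_vec)  # back to a dict the quantizer can digest
assert list(new_dict.keys()) == keys

# The Lp losses compare a layer's FP32 output with its quantized output
print(l2_loss(torch.randn(4, 8), torch.randn(4, 8)))
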
+ search_clipping (bool): if set, optimize clipping values, otherwise optimize scale factor + Returns: + quantized_layer after optimization + """ + params_gen = quantized_layer.named_linear_quant_params(filter=True) if not search_clipping \ + else quantized_layer.named_clipping(filter=True) + init_qp_dict = OrderedDict(params_gen) + + keys, init_qp_vec = quant_params_dict2vec(init_qp_dict, search_clipping) + + def feed_forward_fn(qp_vec): + qp_dict = quant_params_vec2dict(keys, qp_vec, search_clipping) + quantized_layer.update_linear_quant_params(qp_dict) + # Using cloned input, required if the layer is inplace + y = layer(input.clone().detach()) + if getattr(quantized_layer, 'clip_half_range', False): + torch.relu_(y) + q_y = quantized_layer(input.clone().detach()) + loss = loss_fn(y, q_y) + return loss + + result = opt.minimize(feed_forward_fn, init_qp_vec, method=method) # type: opt.OptimizeResult + return quantized_layer + + +def get_input_for_layer(model, layer_name, eval_fn): + layer = dict(model.named_modules())[layer_name] # type: nn.Module + layer_inputs = [] + + def hook_layer_input(module, input): + layer_inputs.append(input[0].clone().detach()) + + handle = layer.register_forward_pre_hook(hook_layer_input) + eval_fn(model) + assert len(layer_inputs) == 1 + handle.remove() + return layer_inputs[0] + + +def init_layer_linear_quant_params(quantizer, original_model, layer_name, init_mode, + init_mode_method='Powell', eval_fn=None, search_clipping=False): + """ + Initializes a layer's linear quant parameters. + This is done to set the scipy.optimize.minimize initial guess. + Args: + quantizer (PostTrainLinearQuantizer): the quantizer, **after** calling prepare model. + original_model (nn.Module): the original, pre-quantized, model. + layer_name (str): the name of the layer. + init_mode (ClipMode or callable or str): the initialization mode. + If ClipMode, the initialization will be according to the respective ClipMode. + If callable - init_mode will be treated as a loss function between the activations pre and post-quantization, + and the initialization process will attempt to find the minimum of that loss function. + E.g. if l1_loss has been passed, the initialization vector will be + scale, zero_point = argmin_{s, zp} (l1_loss(layer(input), q_layer(input; s, zp))) + If str - the mode will be chosen from a list of options. The options are: + [NONE, AVG, LAPLACE, GAUSS, L1, L2 ,L3]. + Defaults to ClipMode.NONE + init_mode_method (str or callable): applicable only in the case of init_mode = 'L1/2/3' or callable. + chooses the minimization method for finding the local argmin_{s, zp}. + Defaults to 'Powell' + eval_fn: evaluation function for the model. Assumed it has a signature of the form + `eval_fn(model)->float`. this is the function to be minimized by the optimization algorithm. + applicable only in the case of init_mode = 'L1/2/3' or callable. 
+ search_clipping (bool): if set, optimize clipping values, otherwise optimize scale factor + """ + denorm_layer_name = distiller.denormalize_module_name(quantizer.model, layer_name) + msglogger.info(denorm_layer_name) + if isinstance(init_mode, str): + init_mode = _init_mode_from_str(init_mode) + layer = dict(original_model.named_modules())[layer_name] + local_args, local_kwargs = quantizer.modules_processed_args[denorm_layer_name] + if isinstance(init_mode, ClipMode): + local_kwargs['clip_acts'] = init_mode + replace_fn = quantizer.replacement_factory.get(type(layer), quantizer.default_repalcement_fn) + quantized_layer = replace_fn(deepcopy(layer), *local_args, **local_kwargs).eval() + if not is_post_train_quant_wrapper(quantized_layer, False): + # the module wasn't quantized, nothing to do here + return + + if callable(init_mode): + input_for_layer = get_input_for_layer(original_model, layer_name, eval_fn) + quantized_layer = optimize_for_layer(layer, quantized_layer, init_mode, input_for_layer, init_mode_method, + search_clipping=search_clipping) + + distiller.model_setattr(quantizer.model, denorm_layer_name, quantized_layer) + quantizer.model.eval() + + +def init_linear_quant_params(quantizer, original_model, eval_fn, dummy_input, init_mode, + init_mode_method=None, search_clipping=False): + """ + Initializes all linear quantization parameters of the model. + Args: + quantizer (PostTrainLinearQuantizer): the quantizer, **after** calling prepare model. + original_model (nn.Module): the original, pre-quantized, model. + init_mode (ClipMode or callable or str or dict): See `init_layer_linear_qaunt_params`. + if init_mode is dict - init_mode is configuration for the different layers, + i.e. init_mode = Dict[layer_name:str, init_mode_layer: ClipMode or callable or str]. + eval_fn: evaluation function for the model. Assumed it has a signature of the form + `eval_fn(model)->float`. this is the function to be minimized by the optimization algorithm. + Note - unlike in `init_layer_linear_quant_params`, this argument is required here. + dummy_input: dummy sample input to the model + init_mode_method: See `init_layer_linear_qaunt_params`. 
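
Taken together, the per-layer initializer above and the whole-model initializer below let the starting point for the search be driven either by a ClipMode or by one of the Lp losses. A hedged sketch of how the whole-model initializer might be invoked once prepare_model() has run; quantizer, original_model, eval_fn and dummy_input are placeholders supplied by the caller, not defined by the patch:

from distiller.quantization.range_linear import ClipMode
from distiller.quantization.ptq_coordinate_search import init_linear_quant_params, l2_loss

def initialize_quant_params(quantizer, original_model, eval_fn, dummy_input):
    # quantizer: PostTrainLinearQuantizer after prepare_model()
    # original_model: an FP32 copy of the model taken before quantization
    # eval_fn(model) -> float evaluates the quantized model on a small evaluation set
    init_linear_quant_params(quantizer, original_model, eval_fn, dummy_input,
                             init_mode=l2_loss,           # or ClipMode.LAPLACE, 'L3', ...
                             init_mode_method='Powell')   # scipy method for the per-layer argmin

init_mode may also be a per-layer dict, e.g. {'conv1': 'L2', 'fc': ClipMode.NONE}, as described in the docstring.
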
+ search_clipping (bool): if set, optimize clipping values, otherwise optimize scale factor + """ + original_model = distiller.make_non_parallel_copy(original_model) + layers_topological_order = SummaryGraph(original_model, dummy_input).layers_topological_order() + q_named_modules = OrderedDict(quantizer.model.named_modules()) + for module_name in layers_topological_order: + # check to see if it was quantized: + q_module = q_named_modules[distiller.denormalize_module_name(quantizer.model, module_name)] + if not is_post_train_quant_wrapper(q_module, False): + continue + module_init_mode = init_mode[module_name] if isinstance(init_mode, dict) else init_mode + msglogger.debug('Initializing layer \'%s\' using %s mode' % (module_name, module_init_mode)) + init_layer_linear_quant_params(quantizer, original_model, module_name, module_init_mode, + init_mode_method=init_mode_method, + eval_fn=eval_fn, + search_clipping=search_clipping) + del original_model + + quantizer._post_prepare_model() + quantizer.model.eval() + + +def get_default_args(): + parser = classifier.init_classifier_compression_arg_parser() + parser.add_argument('--opt-maxiter', dest='maxiter', default=None, type=int, + help='Max iteration for minimization method.') + parser.add_argument('--opt-maxfev', dest='maxfev', default=None, type=int, + help='Max iteration for minimization method.') + parser.add_argument('--opt-method', dest='method', default='Powell', + help='Minimization method used by scip.optimize.minimize.') + parser.add_argument('--opt-bh', dest='basinhopping', action='store_true', default=False, + help='Use scipy.optimize.basinhopping stochastic global minimum search.') + parser.add_argument('--opt-bh-niter', dest='niter', default=100, + help='Number of iterations for the basinhopping algorithm.') + parser.add_argument('--opt-init-mode', dest='init_mode', default='NONE', + choices=list(_INIT_MODES), + help='The mode of quant initalization. Choices: ' + '|'.join(list(_INIT_MODES))) + parser.add_argument('--opt-init-method', dest='init_mode_method', + help='If --opt-init-mode was specified as L1/L2/L3, this specifies the method of ' + 'minimization.') + parser.add_argument('--opt-val-size', type=float, default=1, + help='Use portion of the test size.') + parser.add_argument('--opt-eval-memoize-dataloader', dest='memoize_dataloader', action='store_true', default=False, + help='Stores the input batch in memory to optimize performance.') + parser.add_argument('--base-score', type=float, default=None) + parser.add_argument('--opt-search-clipping', dest='search_clipping', action='store_true', + help='Search on clipping values instead of scale/zero_point.') + args = parser.parse_args() + return args + + +def validate_quantization_settings(args, quantized_model): + if args.search_clipping: + return + for n, m in quantized_model.named_modules(): + if not is_post_train_quant_wrapper(m, False): + continue + + err_msg = 'Detected asymmetric quantization of {}. ' \ + 'Switch to symmetric quantization or enable search_clipping.' 
+ if not isinstance(m, RangeLinearEmbeddingWrapper): + if m.output_quant_settings.num_bits and \ + is_linear_quant_mode_asymmetric(m.mode.activations) and \ + not m.clip_half_range: + raise ValueError(err_msg.format('activations without fused ReLU')) + if isinstance(m, (RangeLinearEmbeddingWrapper, RangeLinearQuantParamLayerWrapper)): + if is_linear_quant_mode_asymmetric(m.mode.weights): + raise ValueError(err_msg.format('weights')) + + +def ptq_coordinate_search(model, dummy_input, eval_fn, method='Powell', options=None, + act_stats=None, args=None, fold_sequences=True, basinhopping=False, + init_args=None, minimizer_kwargs=None, + test_fn=None): + """ + Searches for the optimal post-train quantization configuration (scale/zero_points) + for a model using numerical methods, as described by scipy.optimize.minimize. + Args: + model (nn.Module): model to quantize + dummy_input: an sample expected input to the model + eval_fn (callable): evaluation function for the model. Assumed it has a signature of the form + `eval_fn(model)->float`. this is the function to be minimized by the optimization algorithm. + method (str or callable): minimization method as accepted by scipy.optimize.minimize. + options (dict or None): options for the scipy optimizer + act_stats (OrderedDict): dictionary of statistics per layer, including inputs and outputs. + for more context refer to collect_quant_stats. + args: arguments from command-line. + fold_sequences (bool): flag, indicates to fold sequences before performing the search. + basinhopping (bool): flag, indicates to use basinhopping as a global-minimization method, + will pass the `method` argument to `scipy.optimize.basinhopping`. + init_args (tuple): arguments for initializing the linear quantization parameters. + Refer to `init_linear_quant_params` for more details. + minimizer_kwargs (dict): the kwargs for scipy.optimize.minimize procedure. + test_fn (callable): a function to test the current performance of the model. + """ + if fold_sequences: + model = fold_batch_norms(model, dummy_input) + if args is None: + args = get_default_args() + elif isinstance(args, dict): + updated_args = get_default_args() + updated_args.__dict__.update(args) + args = updated_args + original_model = deepcopy(model) + + if not act_stats and not args.qe_config_file: + msglogger.info('Collecting stats for model...') + model_temp = distiller.utils.make_non_parallel_copy(model) + act_stats = collect_quant_stats(model_temp, eval_fn) + del model_temp + if args: + act_stats_path = '%s_act_stats.yaml' % args.arch + msglogger.info('Done. 
Saving act stats into %s' % act_stats_path) + distiller.yaml_ordered_save(act_stats_path, act_stats) + args.qe_stats_file = act_stats_path + + # Preparing model and init conditions: + msglogger.info("Initializing quantizer...") + quantizer = PostTrainLinearQuantizer.from_args(model, args) + + # Make sure weights are re-quantizable and clip-able + quantizer.save_fp_weights = True + quantizer.also_clip_weights = True + + # Disable any user set activations clipping - we'll be using init_args + quantizer.clip_acts = ClipMode.NONE + for overrides_dict in quantizer.module_overrides_map.values(): + overrides_dict.pop('clip_acts', None) + + quantizer.prepare_model(dummy_input) + quantizer.model.eval() + + validate_quantization_settings(args, quantizer.model) + + msglogger.info("Initializing quantization parameters...") + init_args = init_args or (args.init_mode, args.init_mode_method) + init_linear_quant_params(quantizer, original_model, eval_fn, dummy_input, *init_args, + search_clipping=args.search_clipping) + + msglogger.info("Evaluating initial quantization score...") + best_data = { + 'score': eval_fn(model), + 'qp_dict': deepcopy(quantizer.linear_quant_params) + } + msglogger.info("Evaluation set loss after initialization %.3f" % best_data['score']) + if test_fn: + msglogger.info('Evaluating on full test set...') + l_top1, l_top5, l_loss = test_fn(quantizer.model) + msglogger.info('Test: \tloss=%.3f, top1=%.3f, top5=%.3f ' % (l_loss, l_top1, l_top5)) + + init_qp_dict = OrderedDict(quantizer.named_linear_quant_params(args.search_clipping, filter=True)) + keys, init_qp_vec = quant_params_dict2vec(init_qp_dict, args.search_clipping) + + iter_counter = count(1) + eval_counter = count(1) + + def feed_forward_fn(qp_vec): + # if not _check_qp_vec(keys, qp_vec, quant_mode, args.search_clipping): + # return 1e6 + qp_dict = quant_params_vec2dict(keys, qp_vec, args.search_clipping) + quantizer.update_linear_quant_params(qp_dict) + loss = eval_fn(quantizer.model) + + i = next(eval_counter) + if i % 20 == 0: + msglogger.info('%d evaluations: loss=%.3f' % (i, loss)) + + return loss + + def callback(qp_vec): + score = feed_forward_fn(qp_vec) + i = next(iter_counter) + msglogger.info("Iteration %d: \t Score=%.3f" % (i, score)) + if score < best_data['score']: + best_data['score'] = score + best_data['qp_dict'] = quant_params_vec2dict(keys, qp_vec, args.search_clipping) + msglogger.info("Saving current best quantization parameters.") + if test_fn: + msglogger.info('Evaluating on full test set...') + l_top1, l_top5, l_loss = test_fn(quantizer.model) + msglogger.info('Test: \tloss=%.3f, top1=%.3f, top5=%.3f ' % (l_loss, l_top1, l_top5)) + + options = options or OrderedDict() + if args.maxiter is not None: + options['maxiter'] = args.maxiter + if args.maxfev is not None: + options['maxfev'] = args.maxfev + minimizer_kwargs = minimizer_kwargs or OrderedDict() + minimizer_kwargs.update({ + 'method': method, 'options': options + }) + basinhopping = basinhopping or args.basinhopping + if basinhopping: + msglogger.info('Using basinhopping global minimum search with "%s" local minimization method'% + method) + res = opt.basinhopping(feed_forward_fn, init_qp_vec, args.niter, callback=callback, + minimizer_kwargs=minimizer_kwargs) + else: + msglogger.info('Using "%s" minimization algorithm.' % method) + res = opt.minimize(feed_forward_fn, init_qp_vec, callback=callback, **minimizer_kwargs) + + msglogger.info("Optimization done. 
Best configuration: %s" % best_data['qp_dict']) + return model, best_data['qp_dict'] + + +if __name__ == "__main__": + args = get_default_args() + args.epochs = float('inf') # hack for args parsing so there's no error in epochs + cc = classifier.ClassifierCompressor(args, script_dir=os.path.dirname(__file__)) + + args = deepcopy(cc.args) + + effective_test_size_bak = args.effective_test_size + args.effective_test_size = args.opt_val_size + eval_data_loader = classifier.load_data(args, load_train=False, load_val=False, load_test=True, fixed_subset=True) + + args.effective_test_size = effective_test_size_bak + test_data_loader = classifier.load_data(args, load_train=False, load_val=False, load_test=True) + + # logging + logging.getLogger().setLevel(logging.WARNING) + msglogger = logging.getLogger(__name__) + msglogger.setLevel(logging.INFO) + + model = cc.model.eval() + device = next(model.parameters()).device + + if args.memoize_dataloader: + memoized_data_loader = [] + for images, targets in eval_data_loader: + batch = images.to(device), targets.to(device) + memoized_data_loader.append(batch) + else: + memoized_data_loader = None + + def eval_fn(model): + if args.memoize_dataloader: + loss = 0 + for images, targets in memoized_data_loader: + outputs = model(images) + loss += cc.criterion(outputs, targets).item() + loss = loss / len(memoized_data_loader) + else: + _, _, loss = classifier.test(eval_data_loader, model, cc.criterion, [cc.tflogger, cc.pylogger], + None, args) + return loss + + def test_fn(model): + return classifier.test(test_data_loader, model, cc.criterion, [cc.tflogger, cc.pylogger], None, args) + + args.device = device + if args.resumed_checkpoint_path: + args.load_model_path = args.resumed_checkpoint_path + if args.load_model_path: + msglogger.info("Loading checkpoint from %s" % args.load_model_path) + model = apputils.load_lean_checkpoint(model, args.load_model_path, + model_device=args.device) + + if args.qe_stats_file: + msglogger.info("Loading stats from %s" % args.qe_stats_file) + with open(args.qe_stats_file, 'r') as f: + act_stats = distiller.yaml_ordered_load(f) + else: + act_stats = None + + dummy_input = torch.rand(*model.input_shape, device=args.device) + model, qp_dict = ptq_coordinate_search(model, dummy_input, eval_fn, args.method, + args=args, act_stats=act_stats, test_fn=test_fn) + + top1, top5, loss = test_fn(model) + + msglogger.info("Arch: %s \tTest: \t top1 = %.3f \t top5 = %.3f \t loss = %.3f" % + (args.arch, top1, top5, loss)) + distiller.yaml_ordered_save('%s.quant_params_dict.yaml' % args.arch, qp_dict) + + distiller.apputils.save_checkpoint(0, args.arch, model, extras={'top1': top1, 'qp_dict': qp_dict}, name=args.name, + dir=cc.logdir) diff --git a/distiller/quantization/quantizer.py b/distiller/quantization/quantizer.py index 6f4943c..1b33530 100644 --- a/distiller/quantization/quantizer.py +++ b/distiller/quantization/quantizer.py @@ -23,6 +23,7 @@ import torch.nn as nn import distiller import warnings from typing import Callable, Optional +from copy import deepcopy msglogger = logging.getLogger() @@ -187,6 +188,7 @@ class Quantizer(object): # A dictionary of replaced modules and their respective names. 
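
The small quantizer.py addition recorded in this hunk (modules_processed_args) is what the per-layer initialization above relies on: it stores the (args, kwargs) each wrapper was built with, so the same replacement can be re-run later with tweaked settings. A sketch of that mechanism, mirroring init_layer_linear_quant_params; the helper name is hypothetical and default_repalcement_fn keeps the attribute's existing spelling in Distiller:

from copy import deepcopy

def recreate_wrapper(quantizer, original_layer, full_name):
    # (args, kwargs) captured for each replaced module during prepare_model()
    layer_args, layer_kwargs = quantizer.modules_processed_args[full_name]
    replace_fn = quantizer.replacement_factory.get(type(original_layer),
                                                   quantizer.default_repalcement_fn)
    # Re-run the same replacement on a fresh copy of the FP32 layer
    return replace_fn(deepcopy(original_layer), *layer_args, **layer_kwargs).eval()

In the search code this is done with a modified 'clip_acts' entry in the stored kwargs, which is how different initialization clip modes are tried per layer.
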
self.modules_processed = OrderedDict() + self.modules_processed_args = OrderedDict() def _add_qbits_entry(self, module_name, module_type, qbits): if module_type not in [nn.Conv2d, nn.Conv3d, nn.Linear, nn.Embedding]: @@ -314,6 +316,9 @@ class Quantizer(object): replace_msg(full_name, (module, new_module)) # Add to history of prepared submodules self.modules_processed[module] = full_name, new_module + # To allow recreating this wrapper later on + valid_args = full_name, deepcopy(self.module_qbits_map) + self.modules_processed_args[full_name] = valid_args, valid_kwargs setattr(container, name, new_module) # If a "leaf" module was replaced by a container, add the new layers to the QBits mapping diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py index 287776b..5e3c7f8 100644 --- a/distiller/quantization/range_linear.py +++ b/distiller/quantization/range_linear.py @@ -16,6 +16,7 @@ import torch.nn as nn import torch.nn.functional as f +import numpy as np import argparse from enum import Enum from collections import OrderedDict, namedtuple @@ -124,18 +125,22 @@ def _get_saturation_fn(quant_mode, clip_mode, num_stds, num_bits=None): def _get_quant_params_from_tensor(tensor, num_bits, mode, clip=ClipMode.NONE, per_channel=False, num_stds=None, half_range=False, scale_approx_mult_bits=None): if per_channel and tensor.dim() not in [2, 4]: - raise ValueError('Per channel quantization possible only with 2D or 4D tensors (linear or conv layer weights)') + raise UnsatisfiedRequirements('Per channel quantization possible only with ' + '2D or 4D tensors (linear or conv layer weights)') if clip == ClipMode.N_STD: if per_channel: raise ValueError('N_STD clipping not supported with per-channel quantization') if num_stds is None: - raise ValueError('Clip mode set top N_STD but \'num_stds\' parameter not provided') + raise UnsatisfiedRequirements('Clip mode set top N_STD but \'num_stds\' parameter not provided') dim = 0 if clip == ClipMode.AVG or per_channel else None sat_fn = _get_saturation_fn(mode, clip, num_stds, num_bits) if is_linear_quant_mode_symmetric(mode): sat_val = sat_fn(tensor, dim) + if isinstance(sat_val, tuple): + assert len(sat_val) == 2 + sat_val = torch.max(*sat_val) scale, zp = symmetric_linear_quantization_params(num_bits, sat_val, restrict_qrange=mode == LinearQuantMode.SYMMETRIC_RESTRICTED) else: # Asymmetric mode @@ -193,6 +198,25 @@ def _get_quant_params_from_stats_dict(stats, num_bits, mode, clip=ClipMode.NONE, return scale, zp +def _get_clipping_values(scale, zp, num_bits, mode): + """ + Gets the saturation values induced by quantization values + Args: + scale, zp (torch.Tensor or float): quantization params + num_bits (int): number of bits + mode (LinearQuantMode): mode of quantization + Returns: + min, max : tuple[float, float] + """ + device = scale.device if isinstance(scale, torch.Tensor) else 'cpu' + if is_linear_quant_mode_asymmetric(mode): + t = torch.tensor([0, 2**num_bits-1], device=device) + else: + t = torch.tensor([-2**(num_bits-1), 2**(num_bits-1)-1], device=device) + sat_min, sat_max = linear_dequantize(t, scale, zp) # type: torch.Tensor + return sat_min, sat_max + + ############################################################################### # Post Training ############################################################################### @@ -218,8 +242,7 @@ class QuantSettings(object): return '(num_bits={} ; quant_mode={} ; clip_mode={} ; clip_n_stds={} ; clip_half_range={}' \ ' ; per_channel={})'.format(self.num_bits, 
_enum_to_str(self.quant_mode), _enum_to_str(self.clip_mode), self.clip_n_stds, self.clip_half_range, - self.per_channel - ) + self.per_channel) def linear_quantize_clamp_with_metadata(t, inplace=False): @@ -294,11 +317,15 @@ def add_post_train_quant_args(argparser): group.add_argument('--qe-no-clip-layers', '--qencl', type=str, nargs='+', metavar='LAYER_NAME', default=[], help='List of layer names for which not to clip activations. Applicable ' 'only if --qe-clip-acts is not \'none\'') + group.add_argument('--qe-no-quant-layers', '--qenql', type=str, nargs='+', metavar='LAYER_NAME', default=[], + help='List of layer names for which to skip quantization.') group.add_argument('--qe-per-channel', '--qepc', action='store_true', help='Enable per-channel quantization of weights (per output channel)') group.add_argument('--qe-scale-approx-bits', '--qesab', type=int, metavar='NUM_BITS', help='Enables scale factor approximation using integer multiply + bit shift, using ' 'this number of bits the integer multiplier') + group.add_argument('--qe-save-fp-weights', action='store_true', + help='Allow weights requantization.') stats_group = group.add_mutually_exclusive_group() stats_group.add_argument('--qe-stats-file', type=str, metavar='PATH', @@ -313,6 +340,25 @@ def add_post_train_quant_args(argparser): 'all other --qe* arguments are ignored)') +class UnsatisfiedRequirements(Exception): + pass + + +def _check_clipping_val(val, quant_mode, half_range): + if isinstance(val, float): + if is_linear_quant_mode_symmetric(quant_mode): + return -val, val + elif half_range: + return 0, val + raise ValueError('For asymmetric quantization, setting clipping values only allowed ' + 'using both min/max values.') + if isinstance(val, (tuple, list, np.ndarray, torch.Tensor)): + assert all(distiller.is_scalar(v) for v in val), 'Elements of the clipping value must be scalar-like.' + assert len(val) == 2, 'Clipping value must have 2 elements.' 
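
The helper being defined here normalizes user-supplied clipping values into a (min, max) pair. A few illustrative cases, assuming the symmetric/half-range behavior shown in the code; the numeric values are examples only:

from distiller.quantization.range_linear import _check_clipping_val, LinearQuantMode

# Symmetric mode: a single scalar is mirrored around zero
assert _check_clipping_val(3.0, LinearQuantMode.SYMMETRIC, half_range=False) == (-3.0, 3.0)

# Fused-ReLU (half-range) activations: a single scalar means [0, val]
assert _check_clipping_val(3.0, LinearQuantMode.ASYMMETRIC_UNSIGNED, half_range=True) == (0, 3.0)

# Other asymmetric cases require an explicit (min, max) pair
assert _check_clipping_val((-1.0, 4.0), LinearQuantMode.ASYMMETRIC_SIGNED, half_range=False) == (-1.0, 4.0)
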
+ return tuple(val) + raise TypeError('Clipping value should be a scalar or an iterable of these') + + class RangeLinearQuantWrapper(nn.Module): """ Base class for module which wraps an existing module with linear range-base quantization functionality @@ -356,6 +402,11 @@ class RangeLinearQuantWrapper(nn.Module): self.preset_act_stats = False self.register_buffer('num_forwards', torch.zeros(1, dtype=torch.long)) + self.register_buffer('force_readjust', torch.tensor(False)) + + # The accumulator is always signed + self.accum_min_q_val, self.accum_max_q_val = get_quantized_range(num_bits_accum, signed=True, + signed_restrict_qrange=False) # Activations not quantized - stop here if num_bits_acts is None: @@ -391,9 +442,6 @@ class RangeLinearQuantWrapper(nn.Module): restrict_qrange = mode.activations == LinearQuantMode.SYMMETRIC_RESTRICTED self.acts_min_q_val, self.acts_max_q_val = get_quantized_range(num_bits_acts, signed=signed, signed_restrict_qrange=restrict_qrange) - # The accumulator is always signed - self.accum_min_q_val, self.accum_max_q_val = get_quantized_range(num_bits_accum, signed=True, - signed_restrict_qrange=False) if activation_stats: self.preset_act_stats = True @@ -417,16 +465,73 @@ class RangeLinearQuantWrapper(nn.Module): scale, zp = _get_quant_params_from_stats_dict(activation_stats['output'], num_bits_acts, mode.activations, clip_acts, clip_n_stds, clip_half_range, scale_approx_mult_bits) + if not isinstance(scale, torch.Tensor): + scale, zp = torch.tensor(scale), torch.tensor(zp) self.register_buffer('output_scale', scale) self.register_buffer('output_zero_point', zp) else: self.preset_act_stats = False - def named_acts_quant_params(self): + def named_linear_quant_params(self, filter=False): if self.output_quant_settings.num_bits is not None and self.preset_act_stats: # Output scale buffers are saved in the model only when stats are used yield 'output_scale', self.output_scale - yield 'output_zero_point', self.output_zero_point + if not filter or (is_linear_quant_mode_asymmetric(self.mode.activations) and not self.clip_half_range): + yield 'output_zero_point', self.output_zero_point + + def set_linear_quant_param(self, name, val): + if name in dict(self.named_clipping()): + setattr(self, name, val) + elif name not in dict(self.named_linear_quant_params()): + raise ValueError('%s is not a quantization parameter.' 
% name) + else: + getattr(self, name).data.fill_(val) + self.force_readjust.fill_(True) + + def _check_requirements_output_clipping(self): + if not self.output_quant_settings.num_bits: + raise UnsatisfiedRequirements('Cannot retrieve clipping values because ' + 'the activations aren\'t quantized.') + if not self.preset_act_stats: + raise UnsatisfiedRequirements('Cannot retrieve clipping values ' + 'because the activations stats were not provided.') + + @property + def output_clipping(self): + self._check_requirements_output_clipping() + bits = self.output_quant_settings.num_bits + scale, zp = self.output_scale, self.output_zero_point + return _get_clipping_values(scale, zp, bits, self.output_quant_settings.quant_mode) + + @output_clipping.setter + def output_clipping(self, val): + """ + Args: + val (float or tuple[float, float] or tuple[torch.Tensor, torch.Tensor]): the value to set + """ + self._check_requirements_output_clipping() + qset = self.output_quant_settings + val_min, val_max = _check_clipping_val(val, qset.quant_mode, self.clip_half_range) + qset.clip_mode, qset.clip_half_range, qset.clip_n_stds = ClipMode.NONE, None, None + scale, zp = _get_quant_params_from_stats_dict({'min': val_min, 'max': val_max}, qset.num_bits, qset.quant_mode, + scale_approx_mult_bits=self.scale_approx_mult_bits) + self.set_linear_quant_param('output_scale', scale.item()) + self.set_linear_quant_param('output_zero_point', zp.item()) + + def named_clipping(self, filter=False): + val = self.output_clipping + if filter and (is_linear_quant_mode_symmetric(self.mode.activations) or self.clip_half_range): + val = val[1] + yield 'output_clipping', val + + def update_linear_quant_params(self, new_config): + """ + Updates all the quant params using a dictionary. + Args: + new_config (dict): the new configuration dict. 
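
With these accessors in place, a wrapper's quantization can be adjusted after construction in two ways: directly through scale/zero-point, or through clipping values that are converted back into scale/zero-point. A sketch, assuming a RangeLinearQuantWrapper whose activations are quantized and which was created with activation stats; the numbers are illustrative:

def retune_wrapper(wrapper):
    # Direct route: overwrite the searched parameters, exactly what the
    # coordinate-search loop does through update_linear_quant_params()
    wrapper.update_linear_quant_params({'output_scale': 0.04, 'output_zero_point': 0.0})

    # Clipping route: setting output_clipping recomputes scale/zero_point from the
    # given saturation values (a single float is enough for symmetric/half-range modes)
    wrapper.output_clipping = (-2.0, 2.0)
    print(dict(wrapper.named_clipping()))
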
+ """ + for name, val in new_config.items(): + self.set_linear_quant_param(name, val) def forward(self, *inputs): if self.training: @@ -617,8 +722,8 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper): """ def __init__(self, wrapped_module, num_bits_acts, num_bits_params, num_bits_accum=32, mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE, per_channel_wts=False, activation_stats=None, - clip_n_stds=None, clip_half_range=False, scale_approx_mult_bits=None, - input_overrides=None, inputs_quant_auto_fallback=False): + clip_n_stds=None, clip_half_range=False, scale_approx_mult_bits=None, input_overrides=None, + inputs_quant_auto_fallback=False, save_fp_weights=False, also_clip_weights=False): super(RangeLinearQuantParamLayerWrapper, self).__init__(wrapped_module, num_bits_acts, num_bits_accum, mode, clip_acts, activation_stats, clip_n_stds, clip_half_range, scale_approx_mult_bits, @@ -631,20 +736,32 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper): # If activations are not quantized, we do fake quantization of the parameters, that is - quant and de-quant self.fake_quant_params = self.output_quant_settings.num_bits is None - self.wts_quant_settings = QuantSettings(num_bits_params, self.mode.weights, ClipMode.NONE, None, False, - per_channel_wts) + clip_wts_mode, clip_wts_n_stds = ClipMode.NONE, None + if also_clip_weights: + clip_wts_mode = self.output_quant_settings.clip_mode + clip_wts_n_stds = self.output_quant_settings.clip_n_stds + self.wts_quant_settings = QuantSettings(num_bits_params, self.mode.weights, clip_wts_mode, clip_wts_n_stds, + False, per_channel_wts) self.params_min_q_val, self.params_max_q_val = get_quantized_range( self.wts_quant_settings.num_bits, self.wts_quant_settings.quant_mode != LinearQuantMode.ASYMMETRIC_UNSIGNED, self.wts_quant_settings.quant_mode == LinearQuantMode.SYMMETRIC_RESTRICTED ) + self.save_fp_weights = save_fp_weights + # save the float weight to allow re-quantizing + if save_fp_weights: + wrapped_module.register_buffer('float_weight', wrapped_module.weight.clone().detach()) # Quantize weights - overwrite FP32 weights w_scale, w_zero_point = _get_quant_params_from_tensor(wrapped_module.weight, self.wts_quant_settings.num_bits, self.wts_quant_settings.quant_mode, - per_channel=self.wts_quant_settings.per_channel) + clip=self.wts_quant_settings.clip_mode, + per_channel=self.wts_quant_settings.per_channel, + num_stds=self.wts_quant_settings.clip_n_stds) + w_scale = w_scale if isinstance(w_scale, torch.Tensor) else torch.tensor(w_scale) + w_zero_point = w_zero_point if isinstance(w_zero_point, torch.Tensor) else torch.tensor(w_zero_point) self.register_buffer('w_scale', w_scale) self.register_buffer('w_zero_point', w_zero_point) @@ -655,17 +772,21 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper): self.has_bias = hasattr(wrapped_module, 'bias') and wrapped_module.bias is not None if self.has_bias and (self.fake_quant_params or not self.preset_act_stats): b_scale, b_zero_point = _get_quant_params_from_tensor(wrapped_module.bias, - self.wts_quant_settings.num_bits, - self.wts_quant_settings.quant_mode) + self.accum_quant_settings.num_bits, + self.accum_quant_settings.quant_mode) self.register_buffer('b_scale', b_scale) self.register_buffer('b_zero_point', b_zero_point) base_b_q = linear_quantize_clamp(wrapped_module.bias.data, self.b_scale, self.b_zero_point, - self.params_min_q_val, self.params_max_q_val) + self.accum_min_q_val, self.accum_max_q_val) if not self.preset_act_stats: # Dynamic ranges - 
save in auxiliary buffer, # requantize each time based on dynamic input scale factor self.register_buffer('base_b_q', base_b_q) + # allow requantizing the bias: + if self.has_bias and self.preset_act_stats: + self.register_buffer('fp_bias', self.wrapped_module.bias.data.clone().detach()) + # Activations not quantized - de-quant parameters and return if self.fake_quant_params: linear_dequantize(wrapped_module.weight.data, self.w_scale, self.w_zero_point, inplace=True) @@ -676,7 +797,7 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper): # Activations are quantized - setup accumulator quantization parameters device = self.w_scale.device if self.preset_act_stats: - t = torch.zeros_like(self.w_scale) + t = torch.empty_like(self.w_scale) if self.wts_quant_settings.per_channel: t = t.squeeze(dim=-1) self.register_buffer('accum_scale', t) @@ -688,7 +809,70 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper): # and all subsequent calls are done with these shifted weights. # Upon calling `self.state_dict()` - we restore the actual quantized weights. # i.e. is_simulated_quant_weight_shifted = False - self.register_buffer('is_simulated_quant_weight_shifted', torch.tensor(0, dtype=torch.uint8, device=device)) + self.register_buffer('is_simulated_quant_weight_shifted', torch.tensor(False, device=device)) + + def named_linear_quant_params(self, filter=False): + if self.save_fp_weights: + yield 'w_scale', self.w_scale + if not filter or is_linear_quant_mode_asymmetric(self.mode.weights): + yield 'w_zero_point', self.w_zero_point + yield from super(RangeLinearQuantParamLayerWrapper, self).named_linear_quant_params(filter=filter) + + def set_linear_quant_param(self, name, val): + if name in ['w_scale', 'w_zero_point']: + if self.save_fp_weights: + super().set_linear_quant_param(name, val) + self.wrapped_module.weight.data.copy_(self.wrapped_module.float_weight.data) + linear_quantize_clamp(self.wrapped_module.weight.data, self.w_scale, self.w_zero_point, + self.params_min_q_val, + self.params_max_q_val, inplace=True) + if self.fake_quant_params: + linear_dequantize(self.wrapped_module.weight.data, self.w_scale, self.w_zero_point, inplace=True) + + else: + raise UnsatisfiedRequirements('Cannot re-quantize the weights. Please specify \'save_fp_weights\' in ' + 'the %s constructor to enable re-quantizing the weights.' % + self.__class__.__name__) + else: + super().set_linear_quant_param(name, val) + + def _check_requirements_weights_clipping(self, setter=False): + if not self.wts_quant_settings.num_bits: + raise UnsatisfiedRequirements('Cannot retrieve clipping values because the weights aren\'t quantized.') + if setter and not self.save_fp_weights: + warnings.warn('Without saving fp32 version of weights, re-quantization is disabled. 
To enable, ' + 'please set \'save_fp_weights\' while constructing the wrapper.') + + @property + def weight_clipping(self): + self._check_requirements_weights_clipping(setter=False) + bits, mode = self.wts_quant_settings.num_bits, self.wts_quant_settings.quant_mode + scale, zp = self.w_scale, self.w_zero_point + return _get_clipping_values(scale, zp, bits, mode) + + @weight_clipping.setter + def weight_clipping(self, val): + self._check_requirements_weights_clipping(setter=True) + bits = self.wts_quant_settings.num_bits + val_min, val_max = _check_clipping_val(val, self.wts_quant_settings.quant_mode, False) + if is_linear_quant_mode_symmetric(self.wts_quant_settings.quant_mode): + # in symmetric quantization - we only need one value + scale, zp = symmetric_linear_quantization_params(bits, abs(max(val_min, val_max))) + else: + signed = self.wts_quant_settings.quant_mode == LinearQuantMode.ASYMMETRIC_SIGNED + scale, zp = asymmetric_linear_quantization_params(bits, val_min, val_max, signed=signed) + self.set_linear_quant_param('w_scale', scale) + self.set_linear_quant_param('w_zero_point', zp) + + def named_clipping(self, filter=False): + try: + yield from super().named_clipping(filter=filter) + except UnsatisfiedRequirements as ex: + warnings.warn(str(ex)) # probably the output isn't quantized + val = self.weight_clipping + if filter and is_linear_quant_mode_symmetric(self.mode.weights): + val = val[1] + yield 'weight_clipping', val def state_dict(self, destination=None, prefix='', keep_vars=False): if not self.fake_quant_params and self.is_simulated_quant_weight_shifted: @@ -709,12 +893,15 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper): return accum_scale if self.preset_act_stats: - if self.num_forwards == 0: - self.accum_scale += get_accum_scale(input_q) + if self.num_forwards == 0 or self.force_readjust: + self.accum_scale.copy_(get_accum_scale(input_q)) if self.has_bias: # Requantize bias to accumulator scale "permanently" - linear_quantize_clamp(self.wrapped_module.bias.data, self.accum_scale.squeeze(), 0, - self.accum_min_q_val, self.accum_max_q_val, inplace=True) + self.wrapped_module.bias.data.copy_( + linear_quantize_clamp(self.fp_bias, self.accum_scale.squeeze(), 0, + self.accum_min_q_val, self.accum_max_q_val) + ) + self.force_readjust.fill_(False) else: self.accum_scale = get_accum_scale(input_q) if self.has_bias: @@ -997,7 +1184,7 @@ class FP16Wrapper(FPWrapper): class RangeLinearEmbeddingWrapper(nn.Module): - def __init__(self, wrapped_module, num_bits, mode=LinearQuantMode.SYMMETRIC, stats=None): + def __init__(self, wrapped_module, num_bits, mode=LinearQuantMode.SYMMETRIC, stats=None, save_fp_weights=False): if not isinstance(wrapped_module, nn.Embedding): raise ValueError(self.__class__.__name__ + ' can only wrap torch.nn.Embedding modules') @@ -1006,9 +1193,16 @@ class RangeLinearEmbeddingWrapper(nn.Module): mode = verify_quant_mode(mode) self.mode = mode - self.min_q_val, self.max_q_val = get_quantized_range( - num_bits, signed=mode.weights != LinearQuantMode.ASYMMETRIC_UNSIGNED, - signed_restrict_qrange=mode.weights == LinearQuantMode.SYMMETRIC_RESTRICTED) + self.wts_quant_settings = QuantSettings(num_bits, self.mode.weights, ClipMode.NONE, None, False, False) + + self.params_min_q_val, self.params_max_q_val = get_quantized_range( + self.wts_quant_settings.num_bits, + self.wts_quant_settings.quant_mode != LinearQuantMode.ASYMMETRIC_UNSIGNED, + self.wts_quant_settings.quant_mode == LinearQuantMode.SYMMETRIC_RESTRICTED + ) + self.save_fp_weights = 
save_fp_weights + if save_fp_weights: + wrapped_module.register_buffer('float_weight', wrapped_module.weight.clone().detach()) if stats is None: w_scale, w_zero_point = _get_quant_params_from_tensor(wrapped_module.weight, num_bits, mode.weights) @@ -1016,17 +1210,76 @@ class RangeLinearEmbeddingWrapper(nn.Module): w_scale, w_zero_point = _get_quant_params_from_stats_dict(stats['output'], num_bits, mode.weights) device = wrapped_module.weight.device - self.register_buffer('w_scale', w_scale.to(device)) self.register_buffer('w_zero_point', w_zero_point.to(device)) - linear_quantize_clamp(wrapped_module.weight.data, self.w_scale, self.w_zero_point, self.min_q_val, - self.max_q_val, inplace=True) - self.quant_metadata = TensorQuantMetadata(self.w_scale, self.w_zero_point, self.min_q_val, self.max_q_val) + linear_quantize_clamp(wrapped_module.weight.data, self.w_scale, self.w_zero_point, + self.params_min_q_val, self.params_max_q_val, inplace=True) + self.quant_metadata = TensorQuantMetadata(self.w_scale, self.w_zero_point, + self.params_min_q_val, self.params_max_q_val) self.wrapped_module = wrapped_module - def named_acts_quant_params(self): + def named_linear_quant_params(self, filter=False): yield 'w_scale', self.w_scale - yield 'w_zero_point', self.w_zero_point + if not filter or is_linear_quant_mode_asymmetric(self.mode.weights): + yield 'w_zero_point', self.w_zero_point + + def set_linear_quant_param(self, name, val): + if name in ['w_scale', 'w_zero_point']: + if self.save_fp_weights: + getattr(self, name).fill_(val) + self.wrapped_module.weight.data.copy_(self.wrapped_module.float_weight.data) + linear_quantize_clamp(self.wrapped_module.weight.data, self.w_scale, self.w_zero_point, + self.params_min_q_val, + self.params_max_q_val, inplace=True) + else: + raise UnsatisfiedRequirements('Cannot re-quantize the weights. Please specify \'save_fp_weights\' in ' + 'the %s constructor to enable re-quantizing the weights.' % + self.__class__.__name__) + else: + raise KeyError('No quantization parameter called \'%s\'.' % name) + + def update_linear_quant_params(self, new_config): + """ + Updates all the quant params using a dictionary. + Args: + new_config (dict): the new configuration dict. + """ + for name, val in new_config.items(): + self.set_linear_quant_param(name, val) + + def _check_requirements_weights_clipping(self, setter=False): + if not self.wts_quant_settings.num_bits: + raise UnsatisfiedRequirements('Cannot retrieve clipping values because the weights aren\'t quantized.') + if setter and not self.save_fp_weights: + warnings.warn('Without saving fp32 version of weights, re-quantization is disabled. 
To enable, ' + 'please set \'save_fp_weights\' while constructing the wrapper.') + + @property + def weight_clipping(self): + self._check_requirements_weights_clipping(setter=False) + bits, mode = self.wts_quant_settings.num_bits, self.wts_quant_settings.quant_mode + scale, zp = self.w_scale, self.w_zero_point + return _get_clipping_values(scale, zp, bits, mode) + + @weight_clipping.setter + def weight_clipping(self, val): + self._check_requirements_weights_clipping(setter=True) + bits = self.wts_quant_settings.num_bits + val_min, val_max = _check_clipping_val(val, self.wts_quant_settings.quant_mode, False) + if is_linear_quant_mode_symmetric(self.wts_quant_settings.quant_mode): + # in symmetric quantization - we only need one value + scale, zp = symmetric_linear_quantization_params(bits, val_max) + else: + signed = self.wts_quant_settings.quant_mode == LinearQuantMode.ASYMMETRIC_SIGNED + scale, zp = asymmetric_linear_quantization_params(bits, val_min, val_max, signed=signed) + self.set_linear_quant_param('w_scale', scale) + self.set_linear_quant_param('w_zero_point', zp) + + def named_clipping(self, filter=False): + val = self.weight_clipping + if filter and is_linear_quant_mode_symmetric(self.mode.weights): + val = val[1] + yield 'weight_clipping', val def forward(self, input): out_q = self.wrapped_module(input) @@ -1038,7 +1291,14 @@ class RangeLinearEmbeddingWrapper(nn.Module): class RangeLinearFakeQuantWrapper(RangeLinearQuantWrapper): def __init__(self, wrapped_module, num_bits_acts, mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE, activation_stats=None, clip_n_stds=None, clip_half_range=False, scale_approx_mult_bits=None, - fpq_module=None, input_overrides=None, inputs_quant_auto_fallback=False): + fpq_module=None, input_overrides=None, inputs_quant_auto_fallback=False, quantize_inputs=False): + if isinstance(wrapped_module, (nn.ReLU, nn.ReLU6)): + # In case of ReLU + Gauss/Laplace clipping, need to clip according to stats before ReLU is applied + clip_half_range = True + if clip_acts in (ClipMode.GAUSS, ClipMode.LAPLACE): + activation_stats['output']['mean'] = activation_stats['inputs'][0]['mean'] + activation_stats['output']['std'] = activation_stats['inputs'][0]['std'] + activation_stats['output']['b'] = activation_stats['inputs'][0]['b'] super(RangeLinearFakeQuantWrapper, self).__init__(wrapped_module, num_bits_acts, mode=mode, clip_acts=clip_acts, activation_stats=activation_stats, clip_n_stds=clip_n_stds, clip_half_range=clip_half_range, @@ -1047,11 +1307,15 @@ class RangeLinearFakeQuantWrapper(RangeLinearQuantWrapper): inputs_quant_auto_fallback=inputs_quant_auto_fallback) self.fpq_module = str(fpq_module) if fpq_module else None self.dtype = torch.float + self.quantize_inputs = quantize_inputs if self.fpq_module: self.dtype = {'16': torch.half, '32': torch.float, '64': torch.double}[self.fpq_module] self.wrapped_module.to(self.dtype) def _prepare_input(self, idx, input): + if not self.quantize_inputs: + return input + previously_quantized = hasattr(input, 'quant_metadata') input.quant_metadata = self._get_input_quant_metadata(idx, input) if previously_quantized: @@ -1143,6 +1407,9 @@ class PostTrainLinearQuantizer(Quantizer): For details what this does and how to override it. fpq_module (Union[int, str]): use the modules in floating point mode and only quantize their outputs. takes the values (16, 32, 64) only, this will use RangeLinearFakeQuantWrapper. + save_fp_weights (bool): Indicates whether or not to save a copy of the floating point weights. 
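
A short sketch of how the two new constructor knobs might be used when building the quantizer manually rather than through from_args(); the 8-bit symmetric settings and the stats path are placeholders, not defaults mandated by the patch:

from distiller.quantization import PostTrainLinearQuantizer
from distiller.quantization.range_linear import LinearQuantMode

def build_requantizable_quantizer(model, dummy_input, stats_path):
    # save_fp_weights keeps an FP32 copy of each weight so w_scale/w_zero_point can be
    # changed later; also_clip_weights applies the activation clip mode to weights too.
    quantizer = PostTrainLinearQuantizer(model,
                                         bits_activations=8, bits_parameters=8,
                                         mode=LinearQuantMode.SYMMETRIC,
                                         model_activation_stats=stats_path,
                                         save_fp_weights=True,
                                         also_clip_weights=True)
    quantizer.prepare_model(dummy_input)
    return quantizer
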
+ This allows re-quantization of weight after the initial quantization. + Defaults to False for performance. Note: If fpq_module is set, all the layers (except those overridden in `overrides`) will be converted to the set floating point precision, regardless of bits_activations/parameters/accum. @@ -1152,7 +1419,7 @@ class PostTrainLinearQuantizer(Quantizer): per_channel_wts=False, model_activation_stats=None, fp16=False, clip_n_stds=None, clip_half_range=False, scale_approx_mult_bits=None, inputs_quant_auto_fallback=True, - fpq_module=None): + fpq_module=None, save_fp_weights=False, also_clip_weights=False): overrides_bkp = deepcopy(overrides) super(PostTrainLinearQuantizer, self).__init__(model, bits_activations=bits_activations, bits_weights=bits_parameters, bits_bias=bits_accum, @@ -1208,7 +1475,7 @@ class PostTrainLinearQuantizer(Quantizer): def replace_param_layer(module, name, qbits_map, per_channel_wts=per_channel_wts, mode=mode, fp16=fp16, scale_approx_mult_bits=scale_approx_mult_bits, - clip_acts=clip_acts, clip_n_stds=clip_n_stds, clip_half_range=clip_half_range, + clip_acts=None, clip_n_stds=clip_n_stds, clip_half_range=clip_half_range, input_overrides=None, fpq_module=fpq_module, fake=False): fpq_module = _check_fp16_arg(fp16, fpq_module) if fpq_module and not fake: @@ -1216,7 +1483,7 @@ class PostTrainLinearQuantizer(Quantizer): norm_name = distiller.utils.normalize_module_name(name) activation_stats = self.model_activation_stats.get(norm_name, None) - clip_acts = verify_clip_mode(clip_acts) + clip_acts = verify_clip_mode(clip_acts or self.clip_acts) qbits = qbits_map[name] if qbits.acts is not None and qbits.wts is None: # Quantizing only activations equals fake-quantization @@ -1228,7 +1495,8 @@ class PostTrainLinearQuantizer(Quantizer): clip_n_stds=clip_n_stds, clip_half_range=clip_half_range, scale_approx_mult_bits=scale_approx_mult_bits, fpq_module=fpq_module, input_overrides=input_overrides, - inputs_quant_auto_fallback=inputs_quant_auto_fallback) + inputs_quant_auto_fallback=inputs_quant_auto_fallback, + quantize_inputs=False) return RangeLinearQuantParamLayerWrapper(module, qbits.acts, qbits.wts, num_bits_accum=self.bits_accum, mode=mode, clip_acts=clip_acts, @@ -1237,11 +1505,13 @@ class PostTrainLinearQuantizer(Quantizer): clip_n_stds=clip_n_stds, clip_half_range=clip_half_range, scale_approx_mult_bits=scale_approx_mult_bits, input_overrides=input_overrides, - inputs_quant_auto_fallback=inputs_quant_auto_fallback) + inputs_quant_auto_fallback=inputs_quant_auto_fallback, + save_fp_weights=self.save_fp_weights, + also_clip_weights=self.also_clip_weights) def replace_non_param_layer(wrapper_type, module, name, qbits_map, fp16=fp16, scale_approx_mult_bits=scale_approx_mult_bits, - clip_acts=clip_acts, clip_n_stds=clip_n_stds, clip_half_range=clip_half_range, + clip_acts=None, clip_n_stds=clip_n_stds, clip_half_range=clip_half_range, input_overrides=None, inputs_quant_auto_fallback=inputs_quant_auto_fallback, fpq_module=fpq_module, fake=False): fpq_module = _check_fp16_arg(fp16, fpq_module) @@ -1250,7 +1520,7 @@ class PostTrainLinearQuantizer(Quantizer): norm_name = distiller.utils.normalize_module_name(name) activation_stats = self.model_activation_stats.get(norm_name, None) - clip_acts = verify_clip_mode(clip_acts) + clip_acts = verify_clip_mode(clip_acts or self.clip_acts) qbits = qbits_map[name] if fake: @@ -1259,7 +1529,8 @@ class PostTrainLinearQuantizer(Quantizer): clip_n_stds=clip_n_stds, clip_half_range=clip_half_range, 
scale_approx_mult_bits=scale_approx_mult_bits, fpq_module=fpq_module, input_overrides=input_overrides, - inputs_quant_auto_fallback=inputs_quant_auto_fallback) + inputs_quant_auto_fallback=inputs_quant_auto_fallback, + quantize_inputs=False) try: return wrapper_type(module, qbits.acts, mode=mode, clip_acts=clip_acts, activation_stats=activation_stats, @@ -1281,10 +1552,10 @@ class PostTrainLinearQuantizer(Quantizer): stats=self.model_activation_stats.get(norm_name, None)) def replace_fake_quant(module, name, qbits_map, fp16=fp16, - clip_acts=clip_acts, clip_n_stds=clip_n_stds, clip_half_range=clip_half_range, + clip_acts=None, clip_n_stds=clip_n_stds, clip_half_range=clip_half_range, scale_approx_mult_bits=scale_approx_mult_bits, input_overrides=None, inputs_quant_auto_fallback=inputs_quant_auto_fallback, - fpq_module=fpq_module, fake=True, make_identity=False): + fpq_module=fpq_module, fake=True, make_identity=False, quantize_inputs=True): if isinstance(module, (nn.ReLU, nn.ReLU6)) and make_identity: named_modules = OrderedDict(self.model.named_modules()) pred = self.adjacency_map[name].predecessors[0].name @@ -1299,19 +1570,22 @@ class PostTrainLinearQuantizer(Quantizer): return FPWrapper(module, fpq_module) norm_name = distiller.utils.normalize_module_name(name) - clip_acts = verify_clip_mode(clip_acts) + clip_acts = verify_clip_mode(clip_acts or self.clip_acts) return RangeLinearFakeQuantWrapper(module, qbits_map[name].acts, mode=mode, clip_acts=clip_acts, activation_stats=self.model_activation_stats.get(norm_name, None), clip_n_stds=clip_n_stds, clip_half_range=clip_half_range, scale_approx_mult_bits=scale_approx_mult_bits, fpq_module=fpq_module, input_overrides=input_overrides, - inputs_quant_auto_fallback=inputs_quant_auto_fallback) + inputs_quant_auto_fallback=inputs_quant_auto_fallback, + quantize_inputs=quantize_inputs) self.clip_acts = clip_acts self.clip_n_stds = clip_n_stds self.model_activation_stats = model_activation_stats or {} self.bits_accum = bits_accum self.mode = mode + self.save_fp_weights = save_fp_weights + self.also_clip_weights = also_clip_weights self.replacement_factory[nn.Conv2d] = replace_param_layer self.replacement_factory[nn.Conv3d] = replace_param_layer @@ -1341,34 +1615,90 @@ class PostTrainLinearQuantizer(Quantizer): self.default_repalcement_fn = replace_fake_quant self.replacement_blacklist.append(nn.Dropout) - save_dir = msglogger.logdir if hasattr(msglogger, 'logdir') else '.' - self.save_per_layer_parameters(save_dir) + # To be filled by .prepare_model() + self.linear_quant_params = None - def named_acts_quant_params(self): + def named_linear_quant_params(self, yield_clipping_params=False, filter=False): + if yield_clipping_params: + yield from self.named_clipping(filter=filter) + return for module_name, module in self.model.named_modules(): if is_post_train_quant_wrapper(module, include_fpwrapper=False): - for buff_name, buff in module.named_acts_quant_params(): + for buff_name, buff in module.named_linear_quant_params(filter=filter): full_buff_name = "%s.%s" % (module_name, buff_name) yield full_buff_name, buff - def set_act_quant_param(self, name, val): + def named_clipping(self, filter=False): + """ + Gets all the clipping parameters of the model. 
+        yields tuple[str, tuple[torch.Tensor, torch.Tensor]]
+        """
+        for module_name, module in self.model.named_modules():
+            if not is_post_train_quant_wrapper(module, include_fpwrapper=False):
+                continue
+            for clip_name, clip_val in module.named_clipping(filter=filter):  # type: str, tuple[torch.Tensor, torch.Tensor]
+                yield '%s.%s' % (module_name, clip_name), clip_val
+
+    def set_clipping(self, name, val):
+        """
+        Sets a clipping parameter by name.
+        Args:
+            name (str): the name of the clipping parameter.
+            val (tuple[float or torch.Tensor, float or torch.Tensor]): the value of the clipping.
+        """
+        module_name = distiller.param_name_2_module_name(name)
+        clip_name = name.split('.')[-1]
+        module = dict(self.model.named_modules())[module_name]
+        if not is_post_train_quant_wrapper(module, False):
+            raise ValueError('\'%s\' isn\'t a wrapper and has no clipping parameters.' % module_name)
+        if clip_name not in dict(module.named_clipping()):
+            raise ValueError('\'%s\' is not a clipping parameter.' % clip_name)
+        setattr(module, clip_name, val)
+
+    def update_clipping_parameters(self, clipping_config):
+        """
+        Updates all clipping parameters according to a configuration dict.
+        Args:
+            clipping_config (dict[str, tuple[float or torch.Tensor, float or torch.Tensor]]):
+                the clipping configuration.
+        """
+        for name, val in clipping_config.items():
+            self.set_clipping(name, val)
+
+    def _is_clipping_parameter(self, name):
+        module_name = distiller.param_name_2_module_name(name)
+        clip_name = name.split('.')[-1]
+        module = dict(self.model.named_modules())[module_name]
+        return is_post_train_quant_wrapper(module, False) and clip_name in dict(module.named_clipping())
+
+    def force_readjust_wrappers(self):
+        def _force_readjust(module):
+            if isinstance(module, RangeLinearQuantWrapper):
+                module.force_readjust.fill_(True)
+        self.model.apply(_force_readjust)
+
+    def set_linear_quant_param(self, name, val):
         """
         Sets the quant parameter by module_name.quant_param_name.
+        Can also set the clipping values.
         Args:
             name (str): the name of the quant param [module_name].[quant_param_name]
-            val (int or float or torch.Tensor): the new value.
+            val: the new value.
         """
-        self.acts_quant_params[name].fill_(val)
+        if self._is_clipping_parameter(name):
+            self.set_clipping(name, val)
+        else:
+            self.linear_quant_params[name].data.fill_(val)
+        self.force_readjust_wrappers()
 
-    def update_acts_quant_params(self, new_config):
+    def update_linear_quant_params(self, new_config):
         """
         Updates all the quant params using a dictionary.
         Args:
             new_config (dict): the new configuration dict.
""" for k, v in new_config.items(): - self.set_act_quant_param(k, v) - + self.set_linear_quant_param(k, v) @classmethod def from_args(cls, model, args): @@ -1385,9 +1715,15 @@ class PostTrainLinearQuantizer(Quantizer): if args.qe_bits_wts == 0: args.qe_bits_wts = None overrides = OrderedDict( - [(layer, OrderedDict([('clip_acts', 'NONE')])) - for layer in args.qe_no_clip_layers] + [ + (layer, OrderedDict([('bits_activations', None), ('bits_weights', None)])) + for layer in args.qe_no_quant_layers + ] ) + overrides.update(OrderedDict( + [(layer, OrderedDict([('clip_acts', 'NONE')])) + for layer in args.qe_no_clip_layers if layer not in args.qe_no_quant_layers] + )) mode_acts = args.qe_mode_acts or args.qe_mode mode_wts = args.qe_mode_wts or args.qe_mode mode = ModuleQuantMode(mode_acts, mode_wts) @@ -1402,7 +1738,8 @@ class PostTrainLinearQuantizer(Quantizer): clip_n_stds=args.qe_clip_n_stds, scale_approx_mult_bits=args.qe_scale_approx_bits, overrides=overrides, - inputs_quant_auto_fallback=True) + inputs_quant_auto_fallback=True, + save_fp_weights=args.qe_save_fp_weights) def save_per_layer_parameters(self, save_dir=''): defaults = OrderedDict(self.model.quantizer_metadata['params']) @@ -1410,8 +1747,10 @@ class PostTrainLinearQuantizer(Quantizer): defaults.pop('bits_parameters') defaults.pop('bits_accum') out = OrderedDict() - for n, m in self.model.named_modules(): - if distiller.has_children(m): + for n in self.module_overrides_map: + modules_dict = dict(self.model.named_modules()) + m = modules_dict[n] + if distiller.has_children(m) and not is_post_train_quant_wrapper(m, include_fpwrapper=False): continue qbits = self.module_qbits_map[n] d = OrderedDict() @@ -1422,6 +1761,11 @@ class PostTrainLinearQuantizer(Quantizer): actual_v = self.module_overrides_map[n].get(k, v) d[k] = actual_v out[n] = d + if self.linear_quant_params: + out['linear_quant_params'] = lqp_dict = OrderedDict() + for k, v in self.linear_quant_params.items(): # type: str, torch.Tensor + lqp_dict[k] = v.item() + save_path = os.path.join(save_dir, 'layer_quant_params.yaml') distiller.yaml_ordered_save(save_path, out) msglogger.info('Per-layer quantization parameters saved to ' + save_path) @@ -1438,10 +1782,12 @@ class PostTrainLinearQuantizer(Quantizer): # Setting dummy_input to None to make sure SummaryGraph won't be called dummy_input = None elif dummy_input is None: - raise ValueError('PostTrainLinearQuantizer requires dummy input in order to perform certain optimizations') + raise UnsatisfiedRequirements('PostTrainLinearQuantizer requires dummy ' + 'input in order to perform certain optimizations') super(PostTrainLinearQuantizer, self).prepare_model(dummy_input) - self.acts_quant_params = OrderedDict(self.named_acts_quant_params()) + save_dir = msglogger.logdir if hasattr(msglogger, 'logdir') else '.' 
+ self.save_per_layer_parameters(save_dir) def _pre_prepare_model(self, dummy_input): if not self.has_bidi_distiller_lstm: @@ -1519,7 +1865,11 @@ class PostTrainLinearQuantizer(Quantizer): named_modules = OrderedDict(self.model.named_modules()) model_stats = self.model_activation_stats for n, m in named_modules.items(): - if (distiller.has_children(m) and not isinstance(m, SimulatedFoldedBatchNorm) )\ + # Don't fuse if module outputs aren't quantized: + qbits = self.module_qbits_map.get(n, QBits(None, None, None)) + if qbits.acts is None: + continue + if (distiller.has_children(m) and not isinstance(m, SimulatedFoldedBatchNorm))\ or n not in self.adjacency_map or len(self.adjacency_map[n].successors) != 1: continue successor = self.adjacency_map[n].successors[0] @@ -1570,12 +1920,11 @@ class PostTrainLinearQuantizer(Quantizer): def _apply_fuse_relu(self): """Fuses ReLU layers to the linear layers before them.""" model_overrides = self.module_overrides_map - qbits_map = self.module_qbits_map named_modules = dict(self.model.named_modules()) for n, m in named_modules.items(): - # Don't fuse if the module isn't quantized: - qbits = qbits_map.get(n, QBits(None, None, None)) - if qbits.acts is None and qbits.wts is None: + # Don't fuse if module outputs aren't quantized: + qbits = self.module_qbits_map.get(n, QBits(None, None, None)) + if qbits.acts is None: continue if (distiller.has_children(m) and not isinstance(m, SimulatedFoldedBatchNorm))\ or n not in self.adjacency_map or len(self.adjacency_map[n].successors) != 1: @@ -1602,14 +1951,13 @@ class PostTrainLinearQuantizer(Quantizer): self._clip_stats(self.model_activation_stats[name]['output'], -sat_val, sat_val) def _post_prepare_model(self): - if isinstance(self.model, nn.DataParallel): - # We restore the buffers to the master-GPU of the modules: - device = self.model.src_device_obj - m = self.model.module - for param in m.parameters(): - param.data = param.data.to(device) - for buffer in m.buffers(): - buffer.data = buffer.data.to(device) + m = self.model + device = distiller.model_device(m) + for param in m.parameters(): + param.data = param.data.to(device) + for buffer in m.buffers(): + buffer.data = buffer.data.to(device) + self.linear_quant_params = OrderedDict(self.named_linear_quant_params()) ############################################################################### diff --git a/distiller/utils.py b/distiller/utils.py index 3c5c6b3..2af8be9 100755 --- a/distiller/utils.py +++ b/distiller/utils.py @@ -40,6 +40,8 @@ msglogger = logging.getLogger() def model_device(model): """Determine the device the model is allocated on.""" # Source: https://discuss.pytorch.org/t/how-to-check-if-model-is-on-cuda/180 + if isinstance(model, nn.DataParallel): + return model.src_device_obj try: return str(next(model.parameters()).device) except StopIteration: @@ -785,4 +787,10 @@ def model_setattr(model, attr_name, val, register=False): def param_name_2_module_name(param_name): - return '.'.join(param_name.split('.')[:-1]) \ No newline at end of file + return '.'.join(param_name.split('.')[:-1]) + + +def is_scalar(val): + result = isinstance(val, torch.Tensor) and val.dim() == 0 + result |= np.isscalar(val) + return result diff --git a/examples/quantization/post_train_quant/resnet18_imagenet_post_train_lapq.yaml b/examples/quantization/post_train_quant/resnet18_imagenet_post_train_lapq.yaml new file mode 100644 index 0000000..4616d21 --- /dev/null +++ b/examples/quantization/post_train_quant/resnet18_imagenet_post_train_lapq.yaml @@ -0,0 +1,101 @@ 
+# Post-training quantization settings for running in the same conditions as in:
+# Nahshan et al., "Loss Aware Post-training Quantization" (https://arxiv.org/abs/1911.07190), according to the
+# reference implementation found at: https://github.com/ynahshan/nn-quantization-pytorch/tree/master/lapq
+#
+# The settings are:
+#   * Only fake-quantization is done
+#   * Only weights of convolutions and outputs of ReLU are quantized. Pooling, element-wise addition and FC layers
+#     are not quantized
+#   * The first convolution + relu pair isn't quantized
+#   * The last convolution + relu pair ("layer4.1.conv2/relu2") isn't quantized
+#
+# See example invocations and results after the YAML definition
+
+quantizers:
+  post_train_quantizer:
+    class: PostTrainLinearQuantizer
+    # Don't quantize anything by default; override below for what should be quantized
+    bits_activations: null
+    bits_parameters: null
+    bits_accum: 32
+    mode:
+      activations: ASYMMETRIC_UNSIGNED
+      weights: SYMMETRIC
+    model_activation_stats: ../../examples/quantization/post_train_quant/stats/resnet18_quant_stats.yaml
+    per_channel_wts: False
+    inputs_quant_auto_fallback: True
+
+    overrides:
+      # Conv layers inside the ResNet BasicBlock are quantized (except last one, see below)
+      layer.*conv.*:
+        bits_activations: null
+        bits_weights: 4
+      # ReLU layers inside the ResNet BasicBlock are quantized (except last one, see below)
+      layer.*relu.*:
+        bits_activations: 4
+        quantize_inputs: False
+      # Conv layers in downsampling residual connections are quantized
+      .*downsample\.0.*:
+        bits_activations: null
+        bits_weights: 4
+      # The last conv+relu layers are NOT quantized; we specify them directly
+      layer4.1.conv2:
+        bits_activations: null
+        bits_weights: null
+      layer4.1.relu2:
+        bits_activations: null
+
+# Example invocations:
+#   * Preliminaries:
+#       cd <distiller_root>/distiller/quantization
+#       CONFIG_FILE="<distiller_root>/examples/quantization/post_train_quant/resnet18_imagenet_post_train_lapq.yaml"
+#
+#   * Using L3 initialization:
+#     Command:
+#       python ptq_coordinate_search.py -a resnet18 --pretrained <path_to_imagenet> --opt-val-size 0.01 --opt-maxiter 2 --qe-config-file $CONFIG_FILE -b 500 --opt-init-mode L3 --opt-init-method powell --opt-eval-memoize-dataloader --det --opt-search-clipping
+#
+#     Excerpts from output:
+#       ...
+#       Initializing quantizer...
+#       Initializing quantization parameters...
+#       ...
+#       Evaluating initial quantization score...
+#       Evaluation set loss after initialization 2.522
+#       Test: loss=2.650, top1=43.816, top5=68.840
+#       Using "Powell" minimization algorithm.
+#       ...
+#       980 evaluations: loss=1.962
+#       Iteration 0: Score=1.956
+#       Test: loss=2.117, top1=51.662, top5=76.606
+#       ...
+#       2200 evaluations: loss=1.929
+#       Iteration 1: Score=1.926
+#       Test: loss=2.116, top1=51.784, top5=76.712
+#       Optimization Done.
+#       Arch: resnet18 Test: top1 = 51.784 top5 = 76.712 loss = 2.116
+#
+#
+#   * Using LAPLACE initialization:
+#     Command:
+#       python ptq_coordinate_search.py -a resnet18 --pretrained <path_to_imagenet> --opt-val-size 0.01 --opt-maxiter 2 --qe-config-file $CONFIG_FILE -b 500 --opt-init-mode LAPLACE --opt-init-method powell --opt-eval-memoize-dataloader --det --opt-search-clipping
+#
+#     Excerpts from output:
+#       ...
+#       Initializing quantizer...
+#       Initializing quantization parameters...
+#       ...
+#       Evaluating initial quantization score...
+#       Evaluation set loss after initialization 3.376
+#       Evaluating on full test set...
+#       Test: loss=3.509, top1=29.492, top5=53.768
+#       Using "Powell" minimization algorithm.
+#       ...
+# 620 evaluations: loss=2.458 +# Iteration 0: Score=2.458 +# Test: loss=2.650, top1=42.700, top5=68.138 +# ... +# 1780 evaluations: loss=2.277 +# Iteration 1: Score=2.274 +# Test: loss=2.504, top1=45.164, top5=70.400 +# Optimization Done. +# Arch: resnet18 Test: top1 = 45.164 top5 = 70.400 loss = 2.504 diff --git a/tests/test_post_train_quant.py b/tests/test_post_train_quant.py index fbd5d14..95a0295 100644 --- a/tests/test_post_train_quant.py +++ b/tests/test_post_train_quant.py @@ -178,13 +178,13 @@ def linear_bias(): "mode, clip_acts, per_channel_wts, expected_output", [ (LinearQuantMode.ASYMMETRIC_UNSIGNED, ClipMode.NONE, False, - torch.tensor([[7.686200692, 0.241135708, 0.783691051]], dtype=torch.float32)), + torch.tensor([[7.687381776, 0.241172762, 0.783811475]], dtype=torch.float32)), (LinearQuantMode.ASYMMETRIC_UNSIGNED, ClipMode.NONE, True, - torch.tensor([[7.698823529, 0.241531719, 0.784978085]], dtype=torch.float32)), + torch.tensor([[7.699930796, 0.241566456, 0.785090983]], dtype=torch.float32)), (LinearQuantMode.SYMMETRIC, ClipMode.NONE, False, - torch.tensor([[7.718753843, 0.243110357, 0.790108661]], dtype=torch.float32)), + torch.tensor([[7.716609268, 0.243042812, 0.789889138]], dtype=torch.float32)), (LinearQuantMode.SYMMETRIC, ClipMode.NONE, True, - torch.tensor([[7.718753843, 0.243110357, 0.790108661]], dtype=torch.float32)) + torch.tensor([[7.716609268, 0.243042812, 0.789889138]], dtype=torch.float32)) ] ) def test_linear_layer_wrapper(linear_input, linear_weights, linear_bias, @@ -199,8 +199,8 @@ def test_linear_layer_wrapper(linear_input, linear_weights, linear_bias, linear_input = attach_quant_metadata(linear_input, 8, mode, stats=None, clip_mode=clip_acts, per_channel=False, num_stds=None, scale_approx_mult_bits=None) - with pytest.raises(RuntimeError): - model(linear_input) + # with pytest.raises(RuntimeError): + # model(linear_input) model.eval() @@ -688,34 +688,42 @@ def test_acts_quant_params_linear(act1_type, act2_type, bn_out_stats): model = LinearBNSplitAct(act1_type, act2_type) stats = gen_stats_for_model(model) stats['bn']['output'] = bn_out_stats - quantizer = PostTrainLinearQuantizer(model, model_activation_stats=deepcopy(stats)) + quantizer = PostTrainLinearQuantizer(model, model_activation_stats=deepcopy(stats), save_fp_weights=True) quantizer.prepare_model(torch.randn(10, 10)) # get quant params: expected_quant_params_keys = { 'linear.output_zero_point', 'linear.output_scale', + 'linear.w_scale', + 'linear.w_zero_point', 'act1.output_zero_point', 'act1.output_scale', 'act2.output_zero_point', 'act2.output_scale' } - assert set(quantizer.acts_quant_params) == expected_quant_params_keys - quantizer.set_act_quant_param('linear.output_zero_point', 2.) - quantizer.set_act_quant_param('linear.output_scale', 30.) + assert set(quantizer.linear_quant_params) == expected_quant_params_keys + quantizer.set_linear_quant_param('linear.output_zero_point', 2.) + quantizer.set_linear_quant_param('linear.output_scale', 30.) assert model.linear.output_zero_point == 2. assert model.linear.output_scale == 30. + assert model.linear.force_readjust == True + assert model.act1.force_readjust == True expected_quant_param_linear_dict = { 'output_zero_point': torch.tensor(2.), - 'output_scale': 30. 
+ 'output_scale': 30., + 'w_scale': model.linear.w_scale.item(), + 'w_zero_point': model.linear.w_zero_point.item() } - assert dict(model.linear.named_acts_quant_params()) == expected_quant_param_linear_dict + assert dict(model.linear.named_linear_quant_params()) == expected_quant_param_linear_dict new_config = { 'linear.output_zero_point': 4., 'act2.output_scale': 50 } - quantizer.update_acts_quant_params(new_config) + quantizer.update_linear_quant_params(new_config) assert model.linear.output_zero_point == 4 assert model.act2.output_scale == 50 + assert model.linear.force_readjust == True + assert model.act1.force_readjust == True class DummyWordLangModel(nn.Module): @@ -732,18 +740,20 @@ class DummyWordLangModel(nn.Module): @pytest.mark.filterwarnings('ignore:Iterating over a tensor might cause the trace to be incorrect') @pytest.mark.filterwarnings('ignore:Converting a tensor to a Python index might cause the trace to be incorrect') def test_acts_quant_params_rnn(rnn_model): - model = DummyWordLangModel(nn.Embedding(41, 20), rnn_model).cuda() + model = DummyWordLangModel(nn.Embedding(41, 20), rnn_model) stats = gen_stats_for_model(model) quantizer = PostTrainLinearQuantizer(model, model_activation_stats=deepcopy(stats)) - dummy_input = torch.randint(0, 41, size=(79, 23)) + dummy_input = torch.randint(0, 41, size=(10, 1)) quantizer.prepare_model(dummy_input) new_config = { 'rnn.rnn.cells.0.act_o.output_scale': 4, 'embedding.w_scale': torch.tensor(59.0) } - quantizer.update_acts_quant_params(new_config) + quantizer.update_linear_quant_params(new_config) assert model.rnn.rnn.cells[0].act_o.output_scale == 4 assert model.embedding.w_scale == 59.0 + assert model.rnn.rnn.cells[0].act_o.force_readjust.item() is True + assert model.rnn.rnn.cells[0].act_f.force_readjust.item() is True ############################################################################### @@ -767,20 +777,21 @@ def _fake_quant_tensor(tensor, n_bits, mode, per_channel): q_utils.linear_dequantize(tensor, scale, zp, inplace=True) -def _test_wts_only_quant(layer, x, per_channel, bias, num_bits): +def _test_wts_only_quant(layer, x, per_channel, bias, num_bits_wts, num_bits_accum): layer.weight.data = torch.rand_like(layer.weight) if bias: layer.bias.data = torch.rand_like(layer.bias) mode = LinearQuantMode.ASYMMETRIC_UNSIGNED - layer_ptq = RangeLinearQuantParamLayerWrapper(deepcopy(layer), None, num_bits, mode=mode, per_channel_wts=per_channel) + layer_ptq = RangeLinearQuantParamLayerWrapper(deepcopy(layer), None, num_bits_wts, num_bits_accum=num_bits_accum, + mode=mode, per_channel_wts=per_channel) layer_ptq.eval() layer_manual_q = deepcopy(layer) - _fake_quant_tensor(layer_manual_q.weight.data, num_bits, mode, per_channel) + _fake_quant_tensor(layer_manual_q.weight.data, num_bits_wts, mode, per_channel) assert torch.equal(layer_ptq.wrapped_module.weight, layer_manual_q.weight) if bias: - _fake_quant_tensor(layer_manual_q.bias.data, num_bits, mode, False) + _fake_quant_tensor(layer_manual_q.bias.data, num_bits_accum, mode, False) assert torch.equal(layer_ptq.wrapped_module.bias, layer_manual_q.bias) y_ptq = layer_ptq(x) @@ -795,7 +806,7 @@ def test_conv_layer_wrapper_params_only(per_channel, bias): layer = torch.nn.Conv2d(in_ch, 10, 3, bias=bias) x = torch.rand(5, in_ch, 5, 5) - _test_wts_only_quant(layer, x, per_channel, bias, 8) + _test_wts_only_quant(layer, x, per_channel, bias, 8, 32) def test_linear_layer_wrapper_params_only(per_channel, bias): @@ -805,4 +816,4 @@ def test_linear_layer_wrapper_params_only(per_channel, 
bias): x = torch.rand(5, in_features) - _test_wts_only_quant(layer, x, per_channel, bias, 8) + _test_wts_only_quant(layer, x, per_channel, bias, 8, 32) -- GitLab
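
A minimal sketch of how the renamed per-layer quantization-parameter API introduced in this patch might be driven once the quantizer is prepared. The method names and keyword arguments (save_fp_weights, named_linear_quant_params, set_linear_quant_param, update_linear_quant_params, save_per_layer_parameters) are taken from the diff above; the torchvision model, the stats-file path and the concrete parameter values are assumptions used only for illustration.

import torch
from torchvision import models
from distiller.quantization import PostTrainLinearQuantizer

model = models.resnet18()   # stand-in model; any model with pre-collected activation stats works
quantizer = PostTrainLinearQuantizer(model,
                                     model_activation_stats='acts_quant_stats.yaml',  # placeholder path
                                     save_fp_weights=True)   # keep an FP32 copy so weights can be re-quantized
quantizer.prepare_model(torch.randn(1, 3, 224, 224))          # dummy input shape is model-specific

# Scale / zero-point buffers are exposed as '<module_name>.<param_name>':
for name, buff in quantizer.named_linear_quant_params():
    print(name, buff)

# Changing one parameter updates its buffer and flags the wrappers to re-adjust their
# derived re-quantization factors on the next forward pass:
quantizer.set_linear_quant_param('layer1.0.conv1.output_scale', 30.)

# Batch update, e.g. with values produced by an external search such as LAPQ:
quantizer.update_linear_quant_params({
    'layer1.0.conv1.output_zero_point': 2.,
    'layer1.0.relu1.output_scale': 50.,
})

# The current per-layer configuration can be dumped to 'layer_quant_params.yaml':
quantizer.save_per_layer_parameters(save_dir='.')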
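
The clipping interface can be exercised the same way. The sketch below continues with the quantizer from the previous sketch and only re-uses names and values reported by named_clipping(), so no clipping-parameter names need to be assumed; re-quantizing weights after a clipping change relies on the FP32 copy kept by save_fp_weights=True, per the wrapper code in the diff.

# Each entry is ('<module_name>.<clipping_name>', (min, max)), per the docstring above:
for name, (clip_min, clip_max) in quantizer.named_clipping():
    print(name, clip_min, clip_max)

# Re-set a single clipping value; the setter recomputes scale/zero-point from the new
# range (and, for weight clipping, re-quantizes the saved FP32 weights):
name, (clip_min, clip_max) = next(iter(quantizer.named_clipping()))
quantizer.set_clipping(name, (clip_min, clip_max * 0.95))

# Or push a whole configuration at once, e.g. shrinking every range by 10%:
new_cfg = {n: (lo * 0.9, hi * 0.9) for n, (lo, hi) in quantizer.named_clipping()}
quantizer.update_clipping_parameters(new_cfg)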
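
Finally, a short illustration of the two distiller.utils additions. The printed behavior follows directly from the code in the diff; the CUDA branch is only an assumption about a typical multi-GPU setup.

import torch
import torch.nn as nn
import distiller

net = nn.Linear(4, 2)
print(distiller.model_device(net))                   # 'cpu' - device of the first parameter

if torch.cuda.is_available():
    dp_net = nn.DataParallel(net.cuda())
    print(distiller.model_device(dp_net))            # now taken from DataParallel.src_device_obj

# is_scalar() accepts both 0-dim tensors and Python/NumPy scalars:
print(distiller.utils.is_scalar(torch.tensor(3.)))   # True
print(distiller.utils.is_scalar(5))                  # True
print(distiller.utils.is_scalar(torch.ones(3)))      # False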