From 394e3bc616fcf3548e0fe486bbf6e5170479c8dc Mon Sep 17 00:00:00 2001
From: Guy Jacob <guy.jacob@intel.com>
Date: Mon, 17 Feb 2020 10:01:37 +0200
Subject: [PATCH] Post-Train Quant LAPQ Refactoring (#473)

* Move image-classification-specific setup code to a separate script at
  examples/classifier_compression/ptq_lapq.py
* Make the ptq_coordinate_search function completely independent of
  command line arguments
* Change the LAPQ command line args function to update a pre-existing
  parser (changed the command-line argument prefix to 'lapq' for clarity)
* Enable LAPQ from compress_classifier.py (trigger with --qe-lapq)
* Add pointers in documentation
---
 distiller/apputils/image_classifier.py        |   6 +-
 .../quantization/ptq_coordinate_search.py     | 294 +++++++-----------
 distiller/quantization/range_linear.py        |  42 +--
 examples/classifier_compression/README.md     |   1 +
 .../compress_classifier.py                    |  14 +-
 examples/classifier_compression/parser.py     |   2 +-
 examples/classifier_compression/ptq_lapq.py   |  99 ++++++
 .../post_train_quant/command_line.md          |   7 +-
 .../resnet18_imagenet_post_train_lapq.yaml    |  11 +-
 9 files changed, 264 insertions(+), 212 deletions(-)
 create mode 100644 examples/classifier_compression/ptq_lapq.py
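Before the per-file diffs, it may help to see what the refactored, CLI-free entry point is meant to look like from calling code. A minimal sketch under stated assumptions: the tiny model, criterion, and random evaluation batches are illustrative stand-ins, and the quantizer is built with its default 8-bit settings.

```python
import torch
import torch.nn as nn
from distiller.quantization import PostTrainLinearQuantizer
from distiller.quantization.ptq_coordinate_search import ptq_coordinate_search

class TinyNet(nn.Module):
    # Hypothetical stand-in for a real classification model.
    def __init__(self):
        super(TinyNet, self).__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 10)

    def forward(self, x):
        x = self.pool(self.relu(self.conv(x)))
        return self.fc(x.view(x.size(0), -1))

model = TinyNet().eval()
dummy_input = torch.rand(1, 3, 32, 32)
criterion = nn.CrossEntropyLoss()

# Stand-in calibration/evaluation set; in practice a held-out data split.
eval_batches = [(torch.rand(4, 3, 32, 32), torch.randint(0, 10, (4,)))
                for _ in range(2)]

def eval_fn(model):
    # Required contract: eval_fn(model) -> float, the value being minimized.
    with torch.no_grad():
        return sum(criterion(model(x), y).item()
                   for x, y in eval_batches) / len(eval_batches)

# The quantizer is passed in un-prepared; ptq_coordinate_search calls
# prepare_model itself, and collects activation stats if none are attached.
quantizer = PostTrainLinearQuantizer(model)  # default 8-bit settings assumed
qmodel, qp_dict = ptq_coordinate_search(quantizer, dummy_input, eval_fn,
                                        maxiter=2)
```

The key contract, reflected throughout the diffs below: the caller builds an un-prepared `PostTrainLinearQuantizer` and supplies `eval_fn`; everything previously pulled from the parsed `args` namespace is now an explicit keyword argument.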
diff --git a/distiller/apputils/image_classifier.py b/distiller/apputils/image_classifier.py
index 845a96f..c2e956f 100755
--- a/distiller/apputils/image_classifier.py
+++ b/distiller/apputils/image_classifier.py
@@ -204,13 +204,13 @@ class ClassifierCompressor(object):
             self.pylogger, self.activations_collectors, args=self.args)
 
 
-def init_classifier_compression_arg_parser():
+def init_classifier_compression_arg_parser(include_ptq_lapq_args=False):
     '''Common classifier-compression application command-line arguments.
     '''
     SUMMARY_CHOICES = ['sparsity', 'compute', 'model', 'modules', 'png', 'png_w_params']
 
     parser = argparse.ArgumentParser(description='Distiller image classification model compression')
-    parser.add_argument('data', metavar='DIR', help='path to dataset')
+    parser.add_argument('data', metavar='DATASET_DIR', help='path to dataset')
     parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18', type=lambda s: s.lower(),
                         choices=models.ALL_MODEL_NAMES,
                         help='model architecture: ' +
@@ -312,7 +312,7 @@ def init_classifier_compression_arg_parser():
                         help='Load a model without DataParallel wrapping it')
     parser.add_argument('--thinnify', dest='thinnify', action='store_true', default=False,
                         help='physically remove zero-filters and create a smaller model')
-    distiller.quantization.add_post_train_quant_args(parser)
+    distiller.quantization.add_post_train_quant_args(parser, add_lapq_args=include_ptq_lapq_args)
     return parser
 
 
diff --git a/distiller/quantization/ptq_coordinate_search.py b/distiller/quantization/ptq_coordinate_search.py
index cd2f4df..ea8ecf8 100644
--- a/distiller/quantization/ptq_coordinate_search.py
+++ b/distiller/quantization/ptq_coordinate_search.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2019 Intel Corporation
+# Copyright (c) 2020 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -36,11 +36,12 @@ from collections import OrderedDict
 from itertools import count
 import logging
 from copy import deepcopy
-import distiller.apputils.image_classifier as classifier
-import os
-import distiller.apputils as apputils
 import scipy.optimize as opt
 import numpy as np
+import argparse
+
+
+msglogger = logging.getLogger()
 
 
 def quant_params_dict2vec(p_dict, search_clipping=False):
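`quant_params_dict2vec` (together with its inverse, `quant_params_vec2dict`, used further down) exists because scipy's optimizers work on flat numeric vectors, while the quantizer exposes named per-layer parameters. A self-contained illustration of the round trip — a generic re-implementation for clarity, not Distiller's exact code:

```python
from collections import OrderedDict
import numpy as np

def params_dict2vec(p_dict):
    # Flatten an ordered {name: scalar} dict into (keys, vector).
    keys = list(p_dict.keys())
    vec = np.array([float(v) for v in p_dict.values()])
    return keys, vec

def params_vec2dict(keys, vec):
    # Inverse mapping: rebuild the ordered dict from the optimizer's vector.
    return OrderedDict(zip(keys, vec.tolist()))

qp = OrderedDict([('conv1.output_scale', 0.05),
                  ('conv1.output_zero_point', 0.0)])
keys, x0 = params_dict2vec(qp)
assert params_vec2dict(keys, x0) == qp  # round trip is lossless
```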
@@ -172,8 +173,8 @@ def get_input_for_layer(model, layer_name, eval_fn):
     return layer_inputs[0]
 
 
-def init_layer_linear_quant_params(quantizer, original_model, layer_name, init_mode,
-                                   init_mode_method='Powell', eval_fn=None, search_clipping=False):
+def init_layer_linear_quant_params(quantizer, original_model, layer_name, init_mode=ClipMode.NONE,
+                                   init_method='Powell', eval_fn=None, search_clipping=False):
     """
     Initializes a layer's linear quant parameters.
     This is done to set the scipy.optimize.minimize initial guess.
@@ -190,7 +191,7 @@ def init_layer_linear_quant_params(quantizer, original_model, layer_name, init_m
           If str - the mode will be chosen from a list of options. The options are:
           [NONE, AVG, LAPLACE, GAUSS, L1, L2, L3].
           Defaults to ClipMode.NONE
-        init_mode_method (str or callable): applicable only in the case of init_mode = 'L1/2/3' or callable.
+        init_method (str or callable): applicable only in the case of init_mode = 'L1/2/3' or callable.
           chooses the minimization method for finding the local argmin_{s, zp}.
           Defaults to 'Powell'
         eval_fn: evaluation function for the model. Assumed it has a signature of the form
@@ -214,7 +215,7 @@ def init_layer_linear_quant_params(quantizer, original_model, layer_name, init_m
     if callable(init_mode):
         input_for_layer = get_input_for_layer(original_model, layer_name, eval_fn)
-        quantized_layer = optimize_for_layer(layer, quantized_layer, init_mode, input_for_layer, init_mode_method,
+        quantized_layer = optimize_for_layer(layer, quantized_layer, init_mode, input_for_layer, init_method,
                                              search_clipping=search_clipping)
 
     distiller.model_setattr(quantizer.model, denorm_layer_name, quantized_layer)
@@ -222,7 +223,7 @@ def init_layer_linear_quant_params(quantizer, original_model, layer_name, init_m
 
 
 def init_linear_quant_params(quantizer, original_model, eval_fn, dummy_input, init_mode,
-                             init_mode_method=None, search_clipping=False):
+                             init_method='Powell', search_clipping=False):
     """
     Initializes all linear quantization parameters of the model.
     Args:
@@ -235,7 +236,7 @@ def init_linear_quant_params(quantizer, original_model, eval_fn, dummy_input, in
           `eval_fn(model)->float`. this is the function to be minimized by the optimization algorithm.
           Note - unlike in `init_layer_linear_quant_params`, this argument is required here.
         dummy_input: dummy sample input to the model
-        init_mode_method: See `init_layer_linear_qaunt_params`.
+        init_method: See `init_layer_linear_quant_params`.
         search_clipping (bool): if set, optimize clipping values, otherwise optimize scale factor
     """
     original_model = distiller.make_non_parallel_copy(original_model)
@@ -249,7 +250,7 @@ def init_linear_quant_params(quantizer, original_model, eval_fn, dummy_input, in
         module_init_mode = init_mode[module_name] if isinstance(init_mode, dict) else init_mode
         msglogger.debug('Initializing layer \'%s\' using %s mode' % (module_name, module_init_mode))
         init_layer_linear_quant_params(quantizer, original_model, module_name, module_init_mode,
-                                       init_mode_method=init_mode_method,
+                                       init_method=init_method,
                                        eval_fn=eval_fn, search_clipping=search_clipping)
     del original_model
 
@@ -258,37 +259,55 @@ def init_linear_quant_params(quantizer, original_model, eval_fn, dummy_input, in
     quantizer.model.eval()
 
 
-def get_default_args():
-    parser = classifier.init_classifier_compression_arg_parser()
-    parser.add_argument('--opt-maxiter', dest='maxiter', default=None, type=int,
-                        help='Max iteration for minimization method.')
-    parser.add_argument('--opt-maxfev', dest='maxfev', default=None, type=int,
-                        help='Max iteration for minimization method.')
-    parser.add_argument('--opt-method', dest='method', default='Powell',
-                        help='Minimization method used by scip.optimize.minimize.')
-    parser.add_argument('--opt-bh', dest='basinhopping', action='store_true', default=False,
-                        help='Use scipy.optimize.basinhopping stochastic global minimum search.')
-    parser.add_argument('--opt-bh-niter', dest='niter', default=100,
-                        help='Number of iterations for the basinhopping algorithm.')
-    parser.add_argument('--opt-init-mode', dest='init_mode', default='NONE',
-                        choices=list(_INIT_MODES),
-                        help='The mode of quant initalization. Choices: ' + '|'.join(list(_INIT_MODES)))
-    parser.add_argument('--opt-init-method', dest='init_mode_method',
-                        help='If --opt-init-mode was specified as L1/L2/L3, this specifies the method of '
-                             'minimization.')
-    parser.add_argument('--opt-val-size', type=float, default=1,
-                        help='Use portion of the test size.')
-    parser.add_argument('--opt-eval-memoize-dataloader', dest='memoize_dataloader', action='store_true', default=False,
-                        help='Stores the input batch in memory to optimize performance.')
-    parser.add_argument('--base-score', type=float, default=None)
-    parser.add_argument('--opt-search-clipping', dest='search_clipping', action='store_true',
-                        help='Search on clipping values instead of scale/zero_point.')
-    args = parser.parse_args()
-    return args
-
-
-def validate_quantization_settings(args, quantized_model):
-    if args.search_clipping:
+def add_coordinate_search_args(parser: argparse.ArgumentParser):
+    group = parser.add_argument_group('Post-Training Quantization Auto-Optimization (LAPQ) Arguments')
+    group.add_argument('--lapq-maxiter', default=None, type=int,
+                       help='Max number of iterations for the minimization method.')
+    group.add_argument('--lapq-maxfev', default=None, type=int,
+                       help='Max number of function evaluations for the minimization method.')
+    group.add_argument('--lapq-method', default='Powell',
+                       help='Minimization method used by scipy.optimize.minimize.')
+    group.add_argument('--lapq-basinhopping', '--lapq-bh', action='store_true', default=False,
+                       help='Use scipy.optimize.basinhopping stochastic global minimum search.')
+    group.add_argument('--lapq-basinhopping-niter', '--lapq-bh-niter', default=100,
+                       help='Number of iterations for the basinhopping algorithm.')
+    group.add_argument('--lapq-init-mode', default='NONE', choices=list(_INIT_MODES),
+                       help='The mode of quant initialization. Choices: ' + '|'.join(list(_INIT_MODES)))
+    group.add_argument('--lapq-init-method', default='Powell',
+                       help='If --lapq-init-mode was specified as L1/L2/L3, this specifies the method of '
+                            'minimization.')
+    group.add_argument('--lapq-eval-size', type=float, default=1,
+                       help='Portion of test dataset to use for evaluation function.')
+    group.add_argument('--lapq-eval-memoize-dataloader', action='store_true', default=False,
+                       help='Stores the input batch in memory to optimize performance.')
+    group.add_argument('--lapq-search-clipping', action='store_true',
+                       help='Search on clipping values instead of scale/zero_point.')
+
+
+def cmdline_args_to_dict(args):
+    """
+    Convenience function converting command line arguments obtained from add_coordinate_search_args
+    to a dictionary that can be passed as-is to ptq_coordinate_search.
+
+    Example:
+        # Assume pre-existing parser
+        add_coordinate_search_args(parser)
+        args = parser.parse_args()
+
+        # Assume quantizer, dummy_input, eval_fn, and test_fn have been set up
+        lapq_args_dict = cmdline_args_to_dict(args)
+        ptq_coordinate_search(quantizer, dummy_input, eval_fn, test_fn=test_fn, **lapq_args_dict)
+    """
+    prefix = 'lapq_'
+    len_prefix = len(prefix)
+    lapq_args = {k[len_prefix:]: v for k, v in vars(args).items() if k.startswith(prefix)}
+    lapq_args.pop('eval_size')
+    lapq_args.pop('eval_memoize_dataloader')
+    return lapq_args
+
+
+def validate_quantization_settings(quantized_model, search_clipping):
+    if search_clipping:
         return
     for n, m in quantized_model.named_modules():
         if not is_post_train_quant_wrapper(m, False):
@@ -306,55 +325,54 @@ def validate_quantization_settings(args, quantized_model):
             raise ValueError(err_msg.format('weights'))
 
 
-def ptq_coordinate_search(model, dummy_input, eval_fn, method='Powell', options=None,
-                          act_stats=None, args=None, fold_sequences=True, basinhopping=False,
-                          init_args=None, minimizer_kwargs=None,
-                          test_fn=None):
+def ptq_coordinate_search(quantizer, dummy_input, eval_fn, test_fn=None, method='Powell',
+                          maxiter=None, maxfev=None, basinhopping=False, basinhopping_niter=100,
+                          init_mode=ClipMode.NONE, init_method=None, search_clipping=False,
+                          minimizer_kwargs=None):
     """
     Searches for the optimal post-train quantization configuration (scale/zero_points)
     for a model using numerical methods, as described by scipy.optimize.minimize.
     Args:
-        model (nn.Module): model to quantize
+        quantizer (distiller.quantization.PostTrainLinearQuantizer): A configured PostTrainLinearQuantizer object
+            containing the model being quantized
         dummy_input: a sample expected input to the model
         eval_fn (callable): evaluation function for the model. Assumed it has a signature of the form
            `eval_fn(model)->float`. this is the function to be minimized by the optimization algorithm.
-        method (str or callable): minimization method as accepted by scipy.optimize.minimize.
-        options (dict or None): options for the scipy optimizer
-        act_stats (OrderedDict): dictionary of statistics per layer, including inputs and outputs.
-          for more context refer to collect_quant_stats.
-        args: arguments from command-line.
-        fold_sequences (bool): flag, indicates to fold sequences before performing the search.
+        test_fn (callable): a function to test the current performance of the model. Assumed it has a signature of
+            the form `test_fn(model)->dict`, where the returned dict contains relevant results to be logged.
+            For example: {'top-1': VAL, 'top-5': VAL, 'loss': VAL}
+        method (str or callable): Minimization method as accepted by scipy.optimize.minimize.
+        maxiter (int): Maximum number of iterations to perform during minimization
+        maxfev (int): Maximum number of total function evaluations to perform during minimization
         basinhopping (bool): flag, indicates to use basinhopping as a global-minimization method,
           will pass the `method` argument to `scipy.optimize.basinhopping`.
-        init_args (tuple): arguments for initializing the linear quantization parameters.
-          Refer to `init_linear_quant_params` for more details.
-        minimizer_kwargs (dict): the kwargs for scipy.optimize.minimize procedure.
-        test_fn (callable): a function to test the current performance of the model.
+        basinhopping_niter (int): Number of iterations to perform if basinhopping is set
+        init_mode (ClipMode or callable or str or dict): See 'init_linear_quant_params'
+        init_method (str or callable): See 'init_layer_linear_quant_params'
+        search_clipping (bool): Search on clipping values instead of directly on scale/zero-point (scale and zero-
+            point are inferred from the clipping values)
+        minimizer_kwargs (dict): Optional additional arguments for scipy.optimize.minimize
     """
-    if fold_sequences:
-        model = fold_batch_norms(model, dummy_input)
-    if args is None:
-        args = get_default_args()
-    elif isinstance(args, dict):
-        updated_args = get_default_args()
-        updated_args.__dict__.update(args)
-        args = updated_args
-    original_model = deepcopy(model)
-
-    if not act_stats and not args.qe_config_file:
+    if not isinstance(quantizer, PostTrainLinearQuantizer):
+        raise ValueError('Only PostTrainLinearQuantizer supported, but got a {}'.format(quantizer.__class__.__name__))
+    if quantizer.prepared:
+        raise ValueError('Expecting a quantizer for which prepare_model has not been called')
+
+    original_model = deepcopy(quantizer.model)
+    original_model = fold_batch_norms(original_model, dummy_input)
+
+    if not quantizer.model_activation_stats:
         msglogger.info('Collecting stats for model...')
-        model_temp = distiller.utils.make_non_parallel_copy(model)
-        act_stats = collect_quant_stats(model_temp, eval_fn)
+        model_temp = distiller.utils.make_non_parallel_copy(original_model)
+        act_stats = collect_quant_stats(model_temp, eval_fn,
+                                        inplace_runtime_check=True, disable_inplace_attrs=True,
+                                        save_dir=getattr(msglogger, 'logdir', '.'))
         del model_temp
-        if args:
-            act_stats_path = '%s_act_stats.yaml' % args.arch
-            msglogger.info('Done. Saving act stats into %s' % act_stats_path)
-            distiller.yaml_ordered_save(act_stats_path, act_stats)
-            args.qe_stats_file = act_stats_path
+        quantizer.model_activation_stats = act_stats
+        quantizer.model.quantizer_metadata['params']['model_activation_stats'] = act_stats
 
     # Preparing model and init conditions:
     msglogger.info("Initializing quantizer...")
-    quantizer = PostTrainLinearQuantizer.from_args(model, args)
 
     # Make sure weights are re-quantizable and clip-able
     quantizer.save_fp_weights = True
@@ -368,26 +386,26 @@ def ptq_coordinate_search(model, dummy_input, eval_fn, method='Powell', options=
     quantizer.prepare_model(dummy_input)
     quantizer.model.eval()
 
-    validate_quantization_settings(args, quantizer.model)
+    validate_quantization_settings(quantizer.model, search_clipping)
 
     msglogger.info("Initializing quantization parameters...")
-    init_args = init_args or (args.init_mode, args.init_mode_method)
-    init_linear_quant_params(quantizer, original_model, eval_fn, dummy_input, *init_args,
-                             search_clipping=args.search_clipping)
+    init_linear_quant_params(quantizer, original_model, eval_fn, dummy_input, init_mode, init_method,
+                             search_clipping=search_clipping)
 
     msglogger.info("Evaluating initial quantization score...")
     best_data = {
-        'score': eval_fn(model),
+        'score': eval_fn(quantizer.model),
         'qp_dict': deepcopy(quantizer.linear_quant_params)
     }
     msglogger.info("Evaluation set loss after initialization %.3f" % best_data['score'])
     if test_fn:
         msglogger.info('Evaluating on full test set...')
-        l_top1, l_top5, l_loss = test_fn(quantizer.model)
-        msglogger.info('Test: \tloss=%.3f, top1=%.3f, top5=%.3f ' % (l_loss, l_top1, l_top5))
+        results = test_fn(quantizer.model)
+        s = ', '.join(['{} = {:.3f}'.format(k, v) for k, v in results.items()])
+        msglogger.info('Test: ' + s)
 
-    init_qp_dict = OrderedDict(quantizer.named_linear_quant_params(args.search_clipping, filter=True))
-    keys, init_qp_vec = quant_params_dict2vec(init_qp_dict, args.search_clipping)
+    init_qp_dict = OrderedDict(quantizer.named_linear_quant_params(search_clipping, filter=True))
+    keys, init_qp_vec = quant_params_dict2vec(init_qp_dict, search_clipping)
 
     iter_counter = count(1)
     eval_counter = count(1)
@@ -395,7 +413,7 @@ def ptq_coordinate_search(model, dummy_input, eval_fn, method='Powell', options=
     def feed_forward_fn(qp_vec):
         # if not _check_qp_vec(keys, qp_vec, quant_mode, args.search_clipping):
         #     return 1e6
-        qp_dict = quant_params_vec2dict(keys, qp_vec, args.search_clipping)
+        qp_dict = quant_params_vec2dict(keys, qp_vec, search_clipping)
         quantizer.update_linear_quant_params(qp_dict)
         loss = eval_fn(quantizer.model)
 
@@ -411,105 +429,31 @@ def ptq_coordinate_search(model, dummy_input, eval_fn, method='Powell', options=
         msglogger.info("Iteration %d: \t Score=%.3f" % (i, score))
         if score < best_data['score']:
             best_data['score'] = score
-            best_data['qp_dict'] = quant_params_vec2dict(keys, qp_vec, args.search_clipping)
+            best_data['qp_dict'] = quant_params_vec2dict(keys, qp_vec, search_clipping)
             msglogger.info("Saving current best quantization parameters.")
         if test_fn:
             msglogger.info('Evaluating on full test set...')
-            l_top1, l_top5, l_loss = test_fn(quantizer.model)
-            msglogger.info('Test: \tloss=%.3f, top1=%.3f, top5=%.3f ' % (l_loss, l_top1, l_top5))
-
-    options = options or OrderedDict()
-    if args.maxiter is not None:
-        options['maxiter'] = args.maxiter
-    if args.maxfev is not None:
-        options['maxfev'] = args.maxfev
+            results = test_fn(quantizer.model)
+            s = ', '.join(['{} = {:.3f}'.format(k, v) for k, v in results.items()])
+            msglogger.info('Test: ' + s)
+
+    options = OrderedDict()
+    options['maxiter'] = maxiter
+    options['maxfev'] = maxfev
+
     minimizer_kwargs = minimizer_kwargs or OrderedDict()
     minimizer_kwargs.update({
         'method': method, 'options': options
     })
-    basinhopping = basinhopping or args.basinhopping
     if basinhopping:
-        msglogger.info('Using basinhopping global minimum search with "%s" local minimization method'%
-                       method)
-        res = opt.basinhopping(feed_forward_fn, init_qp_vec, args.niter, callback=callback,
+        msglogger.info('Using basinhopping global minimum search with "%s" local minimization method' % method)
+        res = opt.basinhopping(feed_forward_fn, init_qp_vec, basinhopping_niter, callback=callback,
                                minimizer_kwargs=minimizer_kwargs)
     else:
         msglogger.info('Using "%s" minimization algorithm.' % method)
         res = opt.minimize(feed_forward_fn, init_qp_vec, callback=callback, **minimizer_kwargs)
 
-    msglogger.info("Optimization done. Best configuration: %s" % best_data['qp_dict'])
-    return model, best_data['qp_dict']
-
-
-if __name__ == "__main__":
-    args = get_default_args()
-    args.epochs = float('inf')  # hack for args parsing so there's no error in epochs
-    cc = classifier.ClassifierCompressor(args, script_dir=os.path.dirname(__file__))
-
-    args = deepcopy(cc.args)
-
-    effective_test_size_bak = args.effective_test_size
-    args.effective_test_size = args.opt_val_size
-    eval_data_loader = classifier.load_data(args, load_train=False, load_val=False, load_test=True, fixed_subset=True)
-
-    args.effective_test_size = effective_test_size_bak
-    test_data_loader = classifier.load_data(args, load_train=False, load_val=False, load_test=True)
-
-    # logging
-    logging.getLogger().setLevel(logging.WARNING)
-    msglogger = logging.getLogger(__name__)
-    msglogger.setLevel(logging.INFO)
-
-    model = cc.model.eval()
-    device = next(model.parameters()).device
-
-    if args.memoize_dataloader:
-        memoized_data_loader = []
-        for images, targets in eval_data_loader:
-            batch = images.to(device), targets.to(device)
-            memoized_data_loader.append(batch)
-    else:
-        memoized_data_loader = None
-
-    def eval_fn(model):
-        if args.memoize_dataloader:
-            loss = 0
-            for images, targets in memoized_data_loader:
-                outputs = model(images)
-                loss += cc.criterion(outputs, targets).item()
-            loss = loss / len(memoized_data_loader)
-        else:
-            _, _, loss = classifier.test(eval_data_loader, model, cc.criterion, [cc.tflogger, cc.pylogger],
-                                         None, args)
-        return loss
-
-    def test_fn(model):
-        return classifier.test(test_data_loader, model, cc.criterion, [cc.tflogger, cc.pylogger], None, args)
-
-    args.device = device
-    if args.resumed_checkpoint_path:
-        args.load_model_path = args.resumed_checkpoint_path
-    if args.load_model_path:
-        msglogger.info("Loading checkpoint from %s" % args.load_model_path)
-        model = apputils.load_lean_checkpoint(model, args.load_model_path,
-                                              model_device=args.device)
-
-    if args.qe_stats_file:
-        msglogger.info("Loading stats from %s" % args.qe_stats_file)
-        with open(args.qe_stats_file, 'r') as f:
-            act_stats = distiller.yaml_ordered_load(f)
-    else:
-        act_stats = None
-
-    dummy_input = torch.rand(*model.input_shape, device=args.device)
-    model, qp_dict = ptq_coordinate_search(model, dummy_input, eval_fn, args.method,
-                                           args=args, act_stats=act_stats, test_fn=test_fn)
-
-    top1, top5, loss = test_fn(model)
-
-    msglogger.info("Arch: %s \tTest: \t top1 = %.3f \t top5 = %.3f \t loss = %.3f" %
-                   (args.arch, top1, top5, loss))
-    distiller.yaml_ordered_save('%s.quant_params_dict.yaml' % args.arch, qp_dict)
-
-    distiller.apputils.save_checkpoint(0, args.arch, model, extras={'top1': top1, 'qp_dict': qp_dict}, name=args.name,
-                                       dir=cc.logdir)
+    msglogger.info('Optimization done')
+    msglogger.info('Best score: {}'.format(best_data['score']))
+    msglogger.info('Best Configuration: {}'.format(best_data['qp_dict']))
+    return quantizer.model, best_data['qp_dict']
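As the hunks above show, `maxiter`/`maxfev` now feed straight into the scipy `options` dict, and `basinhopping` simply wraps the same local method. A toy, self-contained demo of that machinery on a quadratic objective (nothing Distiller-specific; the objective and settings are arbitrary):

```python
from itertools import count
import numpy as np
import scipy.optimize as opt

def objective(x):
    # Stand-in for feed_forward_fn: maps any parameter vector to a scalar loss.
    return float(np.sum((x - 1.5) ** 2))

iter_counter = count(1)

def callback(x, *extra):
    # scipy calls this once per iteration; basinhopping passes extra args.
    print('Iteration %d: score=%.4f' % (next(iter_counter), objective(x)))

x0 = np.zeros(3)
minimizer_kwargs = {'method': 'Powell',
                    'options': {'maxiter': 5, 'maxfev': 200}}

# Local search only (the default path)...
res = opt.minimize(objective, x0, callback=callback, **minimizer_kwargs)
# ...or global search, with basinhopping wrapping the same local method
# (third positional argument is niter, as in the patched code).
res = opt.basinhopping(objective, x0, 3, callback=callback,
                       minimizer_kwargs=minimizer_kwargs)
```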
diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index ef92314..31bfc1d 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -268,7 +268,7 @@ def linear_dequantize_with_metadata(t, inplace=False):
     return t
 
 
-def add_post_train_quant_args(argparser):
+def add_post_train_quant_args(argparser, add_lapq_args=False):
     str_to_quant_mode_map = OrderedDict([
         ('sym', LinearQuantMode.SYMMETRIC),
         ('sym_restr', LinearQuantMode.SYMMETRIC_RESTRICTED),
@@ -293,8 +293,7 @@ def add_post_train_quant_args(argparser):
     linear_quant_mode_str_optional = partial(from_dict, d=str_to_quant_mode_map, optional=True)
     clip_mode_str = partial(from_dict, d=str_to_clip_mode_map, optional=False)
 
-    group = argparser.add_argument_group('Arguments controlling quantization at evaluation time '
-                                         '("post-training quantization")')
+    group = argparser.add_argument_group('Post-Training Quantization Arguments')
     group.add_argument('--quantize-eval', '--qe', action='store_true',
                        help='Apply linear quantization to model before evaluation. Applicable only if '
                             '--evaluate is also set')
@@ -338,15 +337,21 @@ def add_post_train_quant_args(argparser):
 
     stats_group = group.add_mutually_exclusive_group()
     stats_group.add_argument('--qe-stats-file', type=str, metavar='PATH',
-                            help='Path to YAML file with pre-made calibration stats')
+                             help='Path to YAML file with pre-made calibration stats')
     stats_group.add_argument('--qe-dynamic', action='store_true', help='Apply dynamic quantization')
     stats_group.add_argument('--qe-calibration', type=distiller.utils.float_range_argparse_checker(exc_min=True),
-                            metavar='PORTION_OF_TEST_SET', default=None,
-                            help='Run the model in evaluation mode on the specified portion of the test dataset and '
-                                 'collect statistics')
+                             metavar='PORTION_OF_TEST_SET', default=None,
+                             help='Run the model in evaluation mode on the specified portion of the test dataset and '
+                                  'collect statistics')
     stats_group.add_argument('--qe-config-file', type=str, metavar='PATH',
-                             help='Path to YAML file containing configuration for PostTrainLinearQuantizer (if present, '
-                                  'all other --qe* arguments are ignored)')
+                             help='Path to YAML file containing configuration for PostTrainLinearQuantizer '
+                                  '(if present, all other --qe* arguments are ignored)')
+
+    if add_lapq_args:
+        from .ptq_coordinate_search import add_coordinate_search_args
+        group.add_argument('--qe-lapq', '--qe-coordinate-search', action='store_true',
+                           help='Optimize post-training quantization parameters using LAPQ method')
+        add_coordinate_search_args(argparser)
 
 
 class UnsatisfiedRequirements(Exception):
@@ -1597,15 +1602,6 @@ class PostTrainLinearQuantizer(Quantizer):
                 model_activation_stats = distiller.utils.yaml_ordered_load(stream)
         elif not isinstance(model_activation_stats, (dict, OrderedDict)):
             raise TypeError('model_activation_stats must either be a string, a dict / OrderedDict or None')
-        else:
-            msglogger.warning("\nWARNING:\nNo stats file passed - Dynamic quantization will be used\n"
-                              "At the moment, this mode isn't as fully featured as stats-based quantization, and "
-                              "the accuracy results obtained are likely not as representative of real-world results."
-                              "\nSpecifically:\n"
-                              " * Not all modules types are supported in this mode. Unsupported modules will remain "
-                              "in FP32.\n"
-                              " * Optimizations for quantization of layers followed by Relu/Tanh/Sigmoid are only "
-                              "supported when statistics are used.\nEND WARNING\n")
 
         mode_dict = {'activations': _enum_to_str(mode.activations),
                      'weights': _enum_to_str(mode.weights)}
         self.model.quantizer_metadata = {'type': type(self),
@@ -1928,6 +1924,16 @@ class PostTrainLinearQuantizer(Quantizer):
         msglogger.info('Per-layer quantization parameters saved to ' + save_path)
 
     def prepare_model(self, dummy_input=None):
+        if not self.model_activation_stats:
+            msglogger.warning("\nWARNING:\nNo stats file passed - Dynamic quantization will be used\n"
+                              "At the moment, this mode isn't as fully featured as stats-based quantization, and "
+                              "the accuracy results obtained are likely not as representative of real-world results."
+                              "\nSpecifically:\n"
+                              " * Not all module types are supported in this mode. Unsupported modules will remain "
+                              "in FP32.\n"
+                              " * Optimizations for quantization of layers followed by Relu/Tanh/Sigmoid are only "
+                              "supported when statistics are used.\nEND WARNING\n")
+
         self.has_bidi_distiller_lstm = any(isinstance(m, distiller.modules.DistillerLSTM) and m.bidirectional
                                            for _, m in self.model.named_modules())
         if self.has_bidi_distiller_lstm:
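With `add_lapq_args=True`, an application parser picks up both the `--qe-*` group and the new `--lapq-*` group. A minimal sketch of the wiring (the argument values passed to `parse_args` are arbitrary examples):

```python
import argparse
import distiller

parser = argparse.ArgumentParser(description='PTQ example')
# Adds the --qe-* group, plus --qe-lapq and the --lapq-* group when requested.
distiller.quantization.add_post_train_quant_args(parser, add_lapq_args=True)
args = parser.parse_args(['--quantize-eval', '--qe-lapq', '--lapq-maxiter', '2'])
print(args.qe_lapq, args.lapq_maxiter)  # -> True 2
```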
diff --git a/examples/classifier_compression/README.md b/examples/classifier_compression/README.md
index 2c72e02..f893b2d 100644
--- a/examples/classifier_compression/README.md
+++ b/examples/classifier_compression/README.md
@@ -34,6 +34,7 @@ A non-exhaustive list of the methods implemented:
 
 ### Quantization
 - [Post-training quantization](https://github.com/NervanaSystems/distiller/tree/master/examples/quantization/post_train_quant/command_line.md) based on the TensorFlow quantization scheme (originally GEMMLOWP) with additional capabilities.
+  - Optimizing post-training quantization parameters with the [LAPQ](https://arxiv.org/abs/1911.07190) method - see [example YAML](https://github.com/NervanaSystems/distiller/blob/master/examples/quantization/post_train_quant/resnet18_imagenet_post_train_lapq.yaml) file for details.
 - [Quantization-aware training](https://github.com/NervanaSystems/distiller/tree/master/examples/quantization/quant_aware_train): TensorFlow scheme, DoReFa, PACT
 
 ### Knowledge Distillation
diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index a6904f7..88911c4 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -61,6 +61,7 @@ import distiller.apputils as apputils
 import parser
 import os
 import numpy as np
+from ptq_lapq import image_classifier_ptq_lapq
 
 
 # Logger handle
@@ -69,7 +70,7 @@ msglogger = logging.getLogger()
 
 def main():
     # Parse arguments
-    args = parser.add_cmdline_args(classifier.init_classifier_compression_arg_parser()).parse_args()
+    args = parser.add_cmdline_args(classifier.init_classifier_compression_arg_parser(True)).parse_args()
     app = ClassifierCompressorSampleApp(args, script_dir=os.path.dirname(__file__))
     if app.handle_subapps():
         return
@@ -110,10 +111,13 @@ def handle_subapps(model, criterion, optimizer, compression_scheduler, pylogger,
         sensitivity_analysis(model, criterion, test_loader, pylogger, args, sensitivities)
         do_exit = True
     elif args.evaluate:
-        test_loader = load_test_data(args)
-        classifier.evaluate_model(test_loader, model, criterion, pylogger,
-                                  classifier.create_activation_stats_collectors(model, *args.activation_stats),
-                                  args, scheduler=compression_scheduler)
+        if args.quantize_eval and args.qe_lapq:
+            image_classifier_ptq_lapq(model, criterion, pylogger, args)
+        else:
+            test_loader = load_test_data(args)
+            classifier.evaluate_model(test_loader, model, criterion, pylogger,
+                                      classifier.create_activation_stats_collectors(model, *args.activation_stats),
+                                      args, scheduler=compression_scheduler)
         do_exit = True
     elif args.thinnify:
         assert args.resumed_checkpoint_path is not None, \
diff --git a/examples/classifier_compression/parser.py b/examples/classifier_compression/parser.py
index 0385d0e..5697787 100755
--- a/examples/classifier_compression/parser.py
+++ b/examples/classifier_compression/parser.py
@@ -16,7 +16,7 @@
 
 import argparse
 import distiller
-import distiller.quantization
+import distiller.pruning
 import distiller.models as models
 
 
diff --git a/examples/classifier_compression/ptq_lapq.py b/examples/classifier_compression/ptq_lapq.py
new file mode 100644
index 0000000..fa79305
--- /dev/null
+++ b/examples/classifier_compression/ptq_lapq.py
@@ -0,0 +1,99 @@
+#
+# Copyright (c) 2020 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import torch
+from copy import deepcopy
+import logging
+from collections import OrderedDict
+
+import distiller
+import distiller.apputils as apputils
+import distiller.apputils.image_classifier as classifier
+import distiller.quantization.ptq_coordinate_search as lapq
+
+
+msglogger = logging.getLogger()
+
+
+def image_classifier_ptq_lapq(model, criterion, loggers, args):
+    args = deepcopy(args)
+
+    effective_test_size_bak = args.effective_test_size
+    args.effective_test_size = args.lapq_eval_size
+    eval_data_loader = classifier.load_data(args, load_train=False, load_val=False, load_test=True, fixed_subset=True)
+
+    args.effective_test_size = effective_test_size_bak
+    test_data_loader = classifier.load_data(args, load_train=False, load_val=False, load_test=True)
+
+    model = model.eval()
+    device = next(model.parameters()).device
+
+    if args.lapq_eval_memoize_dataloader:
+        images_batches = []
+        targets_batches = []
+        for images, targets in eval_data_loader:
+            images_batches.append(images.to(device))
+            targets_batches.append(targets.to(device))
+        memoized_data_loader = [(torch.cat(images_batches), torch.cat(targets_batches))]
+    else:
+        memoized_data_loader = None
+
+    def eval_fn(model):
+        if memoized_data_loader:
+            loss = 0
+            for images, targets in memoized_data_loader:
+                outputs = model(images)
+                loss += criterion(outputs, targets).item()
+            loss = loss / len(memoized_data_loader)
+        else:
+            _, _, loss = classifier.test(eval_data_loader, model, criterion, loggers, None, args)
+        return loss
+
+    def test_fn(model):
+        top1, top5, loss = classifier.test(test_data_loader, model, criterion, loggers, None, args)
+        return OrderedDict([('top-1', top1), ('top-5', top5), ('loss', loss)])
+
+    args.device = device
+    if args.resumed_checkpoint_path:
+        args.load_model_path = args.resumed_checkpoint_path
+    if args.load_model_path:
+        msglogger.info("Loading checkpoint from %s" % args.load_model_path)
+        model = apputils.load_lean_checkpoint(model, args.load_model_path,
+                                              model_device=args.device)
+
+    quantizer = distiller.quantization.PostTrainLinearQuantizer.from_args(model, args)
+
+    dummy_input = torch.rand(*model.input_shape, device=args.device)
+    model, qp_dict = lapq.ptq_coordinate_search(quantizer, dummy_input, eval_fn, test_fn=test_fn,
+                                                **lapq.cmdline_args_to_dict(args))
+
+    results = test_fn(quantizer.model)
+    msglogger.info("Arch: %s \tTest: \t top1 = %.3f \t top5 = %.3f \t loss = %.3f" %
+                   (args.arch, results['top-1'], results['top-5'], results['loss']))
+    distiller.yaml_ordered_save('%s.quant_params_dict.yaml' % args.arch, qp_dict)
+
+    distiller.apputils.save_checkpoint(0, args.arch, model,
+                                       extras={'top1': results['top-1'], 'qp_dict': qp_dict}, name=args.name,
+                                       dir=msglogger.logdir)
+
+
+if __name__ == "__main__":
+    parser = classifier.init_classifier_compression_arg_parser(include_ptq_lapq_args=True)
+    args = parser.parse_args()
+    args.epochs = float('inf')  # hack for args parsing so there's no error in epochs
+    cc = classifier.ClassifierCompressor(args, script_dir=os.path.dirname(__file__))
+    image_classifier_ptq_lapq(cc.model, cc.criterion, [cc.pylogger, cc.tflogger], cc.args)
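A detail of the new `test_fn` contract used above: it returns a metrics dict rather than the old `(top1, top5, loss)` tuple, and `ptq_coordinate_search` logs whatever keys it finds. A tiny illustration of that contract (the metric values are made up):

```python
from collections import OrderedDict

def test_fn(model):
    # Contract per the docstring: test_fn(model) -> dict of metrics to log.
    return OrderedDict([('top-1', 69.758), ('top-5', 89.078), ('loss', 1.251)])

results = test_fn(None)
print('Test: ' + ', '.join('{} = {:.3f}'.format(k, v) for k, v in results.items()))
# -> Test: top-1 = 69.758, top-5 = 89.078, loss = 1.251
```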
diff --git a/examples/quantization/post_train_quant/command_line.md b/examples/quantization/post_train_quant/command_line.md
index ce2c948..2c112eb 100644
--- a/examples/quantization/post_train_quant/command_line.md
+++ b/examples/quantization/post_train_quant/command_line.md
@@ -30,8 +30,9 @@ Post-training quantization can either be configured straight from the command-li
 | `--qe-stats-file` | N/A | Use stats file for static quantization of activations. See details below | None |
 | `--qe-dynamic` | N/A | Perform dynamic quantization. See details below | None |
 | `--qe-config-file` | N/A | Path to YAML config file. See section above. (ignores all other --qe* arguments) | None |
-| `--qe-convert-pytorch` | `--qept` | Convert the model to PyTorch native post-train quantization modules | Off |
+| `--qe-convert-pytorch` | `--qept` | Convert the model to PyTorch native post-train quantization modules. See [tutorial](https://github.com/NervanaSystems/distiller/blob/master/jupyter/post_train_quant_convert_pytorch.ipynb) for more details | Off |
 | `--qe-pytorch-backend` | N/A | When --qe-convert-pytorch is set, specifies the PyTorch quantization backend to use. Choices: "fbgemm", "qnnpack" | Off |
+| `--qe-lapq` | N/A | Optimize post-training quantization parameters using the [LAPQ](https://arxiv.org/abs/1911.07190) method. Beyond the scope of this document. See [example YAML](https://github.com/NervanaSystems/distiller/blob/master/examples/quantization/post_train_quant/resnet18_imagenet_post_train_lapq.yaml) file for details | Off |
 
 ### Notes
 
@@ -40,10 +41,6 @@ Post-training quantization can either be configured straight from the command-li
   * `--quantize-eval` is also set, in which case an FP32 model is first quantized using Distiller's post-training
     quantization flow, and then converted to a PyTorch native quantization model.
   * `--quantize-eval` is not set, but a previously post-train quantized model is loaded via `--resume`. In this case,
     the loaded model is converted to PyTorch native quantization.
 
-### Conversion to PyTorch Built-in Quantization Model
-
-PyTorch released built-in support for quantization in version 1.3. Currently Distiller's quantization functionality is
-still completely separate from PyTorch's. We provide the ability to take a model which was post-train quantized with
-Distiller, and is comprised of `RangeLinearQuantWrapper`
-
 ## "Net-Aware" Quantization
 
 The term "net-aware" quantization, coined in [this](https://arxiv.org/abs/1811.09886) paper from Facebook (section 3.2.2),
 means we can achieve better quantization by considering sequences of operations instead of just quantizing each operation
 independently. This isn't exactly layer fusion - in Distiller we modify activation stats prior to setting quantization
 parameters, in order to make sure that when a module is followed by certain activation functions, only the relevant
 ranges are quantized.
 We do this for:
diff --git a/examples/quantization/post_train_quant/resnet18_imagenet_post_train_lapq.yaml b/examples/quantization/post_train_quant/resnet18_imagenet_post_train_lapq.yaml
index 87b42f2..b920326 100644
--- a/examples/quantization/post_train_quant/resnet18_imagenet_post_train_lapq.yaml
+++ b/examples/quantization/post_train_quant/resnet18_imagenet_post_train_lapq.yaml
@@ -21,7 +21,7 @@ quantizers:
     mode:
       activations: ASYMMETRIC_UNSIGNED
       weights: SYMMETRIC
-    model_activation_stats: ../../examples/quantization/post_train_quant/stats/resnet18_quant_stats.yaml
+    # model_activation_stats: ../quantization/post_train_quant/stats/resnet18_quant_stats.yaml
     per_channel_wts: False
     inputs_quant_auto_fallback: True
 
@@ -47,12 +47,13 @@ quantizers:
 
 # Example invocations:
 # * Preliminaries:
-#   cd <distiller_root>/distiller/quantization
-#   CONFIG_FILE="../../examples/quantization/post_train_quant/resnet18_imagenet_post_train_lapq.yaml"
+#   cd <distiller_root>/examples/classifier_compression
+#   CONFIG_FILE="../quantization/post_train_quant/resnet18_imagenet_post_train_lapq.yaml"
+#   IMAGENET_PATH=<path_to_imagenet>
 #
 # * Using L3 initialization:
 #   Command:
-#     python ptq_coordinate_search.py -a resnet18 --pretrained <path_to_imagenet> --opt-val-size 0.01 --opt-maxiter 2 --qe-config-file $CONFIG_FILE -b 500 --opt-init-mode L3 --opt-init-method powell --opt-eval-memoize-dataloader --det --opt-search-clipping
+#     python compress_classifier.py --eval --qe --qe-lapq -a resnet18 --pretrained $IMAGENET_PATH --lapq-eval-size 0.01 --lapq-maxiter 2 --qe-config-file $CONFIG_FILE -b 500 --lapq-init-mode L3 --lapq-init-method powell --lapq-eval-memoize-dataloader --det --lapq-search-clipping
 #
 #   Excerpts from output:
 #   ...
@@ -77,7 +78,7 @@ quantizers:
 #
 # * Using LAPLACE initialization:
 #   Command:
-#     python ptq_coordinate_search.py -a resnet18 --pretrained <path_to_imagenet> --opt-val-size 0.01 --opt-maxiter 2 --qe-config-file $CONFIG_FILE -b 500 --opt-init-mode LAPLACE --opt-init-method powell --opt-eval-memoize-dataloader --det --opt-search-clipping
+#     python compress_classifier.py --eval --qe --qe-lapq -a resnet18 --pretrained $IMAGENET_PATH --lapq-eval-size 0.01 --lapq-maxiter 2 --qe-config-file $CONFIG_FILE -b 500 --lapq-init-mode LAPLACE --lapq-init-method powell --lapq-eval-memoize-dataloader --det --lapq-search-clipping
 #
 #   Excerpts from output:
 #   ...
-- 
GitLab