From 394e3bc616fcf3548e0fe486bbf6e5170479c8dc Mon Sep 17 00:00:00 2001
From: Guy Jacob <guy.jacob@intel.com>
Date: Mon, 17 Feb 2020 10:01:37 +0200
Subject: [PATCH] Post-Train Quant LAPQ Refactoring (#473)

* Move image-classification-specific setup code to a separate script at
  examples/classifier_compression/ptq_lapq.py
* Make the ptq_coordinate_search function completely independent of
  command line arguments
* Change the LAPQ command line args function to update a pre-existing
  parser (changed the command-line argument prefix to 'lapq' for clarity)
* Enable LAPQ from compress_classifier.py (trigger with --qe-lapq)
* Add pointers in documentation
---
 distiller/apputils/image_classifier.py        |   6 +-
 .../quantization/ptq_coordinate_search.py     | 294 +++++++-----------
 distiller/quantization/range_linear.py        |  42 +--
 examples/classifier_compression/README.md     |   1 +
 .../compress_classifier.py                    |  14 +-
 examples/classifier_compression/parser.py     |   2 +-
 examples/classifier_compression/ptq_lapq.py   |  99 ++++++
 .../post_train_quant/command_line.md          |   7 +-
 .../resnet18_imagenet_post_train_lapq.yaml    |  11 +-
 9 files changed, 264 insertions(+), 212 deletions(-)
 create mode 100644 examples/classifier_compression/ptq_lapq.py
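Before the per-file diffs, it may help to see what the refactored, CLI-free entry point is meant to look like from calling code. A minimal sketch under stated assumptions: the tiny model, criterion, and random evaluation batches are illustrative stand-ins, and the quantizer is built with its default 8-bit settings.

```python
import torch
import torch.nn as nn
from distiller.quantization import PostTrainLinearQuantizer
from distiller.quantization.ptq_coordinate_search import ptq_coordinate_search

class TinyNet(nn.Module):
    # Hypothetical stand-in for a real classification model.
    def __init__(self):
        super(TinyNet, self).__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 10)

    def forward(self, x):
        x = self.pool(self.relu(self.conv(x)))
        return self.fc(x.view(x.size(0), -1))

model = TinyNet().eval()
dummy_input = torch.rand(1, 3, 32, 32)
criterion = nn.CrossEntropyLoss()

# Stand-in calibration/evaluation set; in practice a held-out data split.
eval_batches = [(torch.rand(4, 3, 32, 32), torch.randint(0, 10, (4,)))
                for _ in range(2)]

def eval_fn(model):
    # Required contract: eval_fn(model) -> float, the value being minimized.
    with torch.no_grad():
        return sum(criterion(model(x), y).item()
                   for x, y in eval_batches) / len(eval_batches)

# The quantizer is passed in un-prepared; ptq_coordinate_search calls
# prepare_model itself, and collects activation stats if none are attached.
quantizer = PostTrainLinearQuantizer(model)  # default 8-bit settings assumed
qmodel, qp_dict = ptq_coordinate_search(quantizer, dummy_input, eval_fn,
                                        maxiter=2)
```

The key contract, reflected throughout the diffs below: the caller builds an un-prepared `PostTrainLinearQuantizer` and supplies `eval_fn`; everything previously pulled from the parsed `args` namespace is now an explicit keyword argument.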
diff --git a/distiller/apputils/image_classifier.py b/distiller/apputils/image_classifier.py
index 845a96f..c2e956f 100755
--- a/distiller/apputils/image_classifier.py
+++ b/distiller/apputils/image_classifier.py
@@ -204,13 +204,13 @@ class ClassifierCompressor(object):
             self.pylogger, self.activations_collectors, args=self.args)
 
 
-def init_classifier_compression_arg_parser():
+def init_classifier_compression_arg_parser(include_ptq_lapq_args=False):
     '''Common classifier-compression application command-line arguments.
     '''
     SUMMARY_CHOICES = ['sparsity', 'compute', 'model', 'modules', 'png', 'png_w_params']
 
     parser = argparse.ArgumentParser(description='Distiller image classification model compression')
-    parser.add_argument('data', metavar='DIR', help='path to dataset')
+    parser.add_argument('data', metavar='DATASET_DIR', help='path to dataset')
     parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18', type=lambda s: s.lower(),
                         choices=models.ALL_MODEL_NAMES,
                         help='model architecture: ' +
@@ -312,7 +312,7 @@ def init_classifier_compression_arg_parser():
                         help='Load a model without DataParallel wrapping it')
     parser.add_argument('--thinnify', dest='thinnify', action='store_true', default=False,
                         help='physically remove zero-filters and create a smaller model')
-    distiller.quantization.add_post_train_quant_args(parser)
+    distiller.quantization.add_post_train_quant_args(parser, add_lapq_args=include_ptq_lapq_args)
     return parser
 
 
diff --git a/distiller/quantization/ptq_coordinate_search.py b/distiller/quantization/ptq_coordinate_search.py
index cd2f4df..ea8ecf8 100644
--- a/distiller/quantization/ptq_coordinate_search.py
+++ b/distiller/quantization/ptq_coordinate_search.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2019 Intel Corporation
+# Copyright (c) 2020 Intel Corporation
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -36,11 +36,12 @@ from collections import OrderedDict
 from itertools import count
 import logging
 from copy import deepcopy
-import distiller.apputils.image_classifier as classifier
-import os
-import distiller.apputils as apputils
 import scipy.optimize as opt
 import numpy as np
+import argparse
+
+
+msglogger = logging.getLogger()
 
 
 def quant_params_dict2vec(p_dict, search_clipping=False):
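`quant_params_dict2vec` (together with its inverse, `quant_params_vec2dict`, used further down) exists because scipy's optimizers work on flat numeric vectors, while the quantizer exposes named per-layer parameters. A self-contained illustration of the round trip — a generic re-implementation for clarity, not Distiller's exact code:

```python
from collections import OrderedDict
import numpy as np

def params_dict2vec(p_dict):
    # Flatten an ordered {name: scalar} dict into (keys, vector).
    keys = list(p_dict.keys())
    vec = np.array([float(v) for v in p_dict.values()])
    return keys, vec

def params_vec2dict(keys, vec):
    # Inverse mapping: rebuild the ordered dict from the optimizer's vector.
    return OrderedDict(zip(keys, vec.tolist()))

qp = OrderedDict([('conv1.output_scale', 0.05),
                  ('conv1.output_zero_point', 0.0)])
keys, x0 = params_dict2vec(qp)
assert params_vec2dict(keys, x0) == qp  # round trip is lossless
```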
@@ -172,8 +173,8 @@ def get_input_for_layer(model, layer_name, eval_fn):
     return layer_inputs[0]
 
 
-def init_layer_linear_quant_params(quantizer, original_model, layer_name, init_mode,
-                                   init_mode_method='Powell', eval_fn=None, search_clipping=False):
+def init_layer_linear_quant_params(quantizer, original_model, layer_name, init_mode=ClipMode.NONE,
+                                   init_method='Powell', eval_fn=None, search_clipping=False):
     """
     Initializes a layer's linear quant parameters.
     This is done to set the scipy.optimize.minimize initial guess.
@@ -190,7 +191,7 @@ def init_layer_linear_quant_params(quantizer, original_model, layer_name, init_m
           If str - the mode will be chosen from a list of options. The options are:
           [NONE, AVG, LAPLACE, GAUSS, L1, L2, L3].
           Defaults to ClipMode.NONE
-        init_mode_method (str or callable): applicable only in the case of init_mode = 'L1/2/3' or callable.
+        init_method (str or callable): applicable only in the case of init_mode = 'L1/2/3' or callable.
           chooses the minimization method for finding the local argmin_{s, zp}.
           Defaults to 'Powell'
         eval_fn: evaluation function for the model. Assumed it has a signature of the form
@@ -214,7 +215,7 @@ def init_layer_linear_quant_params(quantizer, original_model, layer_name, init_m
     if callable(init_mode):
         input_for_layer = get_input_for_layer(original_model, layer_name, eval_fn)
-        quantized_layer = optimize_for_layer(layer, quantized_layer, init_mode, input_for_layer, init_mode_method,
+        quantized_layer = optimize_for_layer(layer, quantized_layer, init_mode, input_for_layer, init_method,
                                              search_clipping=search_clipping)
 
     distiller.model_setattr(quantizer.model, denorm_layer_name, quantized_layer)
@@ -222,7 +223,7 @@ def init_layer_linear_quant_params(quantizer, original_model, layer_name, init_m
 
 
 def init_linear_quant_params(quantizer, original_model, eval_fn, dummy_input, init_mode,
-                             init_mode_method=None, search_clipping=False):
+                             init_method='Powell', search_clipping=False):
     """
     Initializes all linear quantization parameters of the model.
     Args:
@@ -235,7 +236,7 @@ def init_linear_quant_params(quantizer, original_model, eval_fn, dummy_input, in
           `eval_fn(model)->float`. this is the function to be minimized by the optimization algorithm.
           Note - unlike in `init_layer_linear_quant_params`, this argument is required here.
         dummy_input: dummy sample input to the model
-        init_mode_method: See `init_layer_linear_qaunt_params`.
+        init_method: See `init_layer_linear_quant_params`.
         search_clipping (bool): if set, optimize clipping values, otherwise optimize scale factor
     """
     original_model = distiller.make_non_parallel_copy(original_model)
@@ -249,7 +250,7 @@ def init_linear_quant_params(quantizer, original_model, eval_fn, dummy_input, in
         module_init_mode = init_mode[module_name] if isinstance(init_mode, dict) else init_mode
         msglogger.debug('Initializing layer \'%s\' using %s mode' % (module_name, module_init_mode))
         init_layer_linear_quant_params(quantizer, original_model, module_name, module_init_mode,
-                                       init_mode_method=init_mode_method,
+                                       init_method=init_method,
                                        eval_fn=eval_fn, search_clipping=search_clipping)
     del original_model
 
@@ -258,37 +259,55 @@ def init_linear_quant_params(quantizer, original_model, eval_fn, dummy_input, in
     quantizer.model.eval()
 
 
-def get_default_args():
-    parser = classifier.init_classifier_compression_arg_parser()
-    parser.add_argument('--opt-maxiter', dest='maxiter', default=None, type=int,
-                        help='Max iteration for minimization method.')
-    parser.add_argument('--opt-maxfev', dest='maxfev', default=None, type=int,
-                        help='Max iteration for minimization method.')
-    parser.add_argument('--opt-method', dest='method', default='Powell',
-                        help='Minimization method used by scip.optimize.minimize.')
-    parser.add_argument('--opt-bh', dest='basinhopping', action='store_true', default=False,
-                        help='Use scipy.optimize.basinhopping stochastic global minimum search.')
-    parser.add_argument('--opt-bh-niter', dest='niter', default=100,
-                        help='Number of iterations for the basinhopping algorithm.')
-    parser.add_argument('--opt-init-mode', dest='init_mode', default='NONE',
-                        choices=list(_INIT_MODES),
-                        help='The mode of quant initalization. Choices: ' + '|'.join(list(_INIT_MODES)))
-    parser.add_argument('--opt-init-method', dest='init_mode_method',
-                        help='If --opt-init-mode was specified as L1/L2/L3, this specifies the method of '
-                             'minimization.')
-    parser.add_argument('--opt-val-size', type=float, default=1,
-                        help='Use portion of the test size.')
-    parser.add_argument('--opt-eval-memoize-dataloader', dest='memoize_dataloader', action='store_true', default=False,
-                        help='Stores the input batch in memory to optimize performance.')
-    parser.add_argument('--base-score', type=float, default=None)
-    parser.add_argument('--opt-search-clipping', dest='search_clipping', action='store_true',
-                        help='Search on clipping values instead of scale/zero_point.')
-    args = parser.parse_args()
-    return args
-
-
-def validate_quantization_settings(args, quantized_model):
-    if args.search_clipping:
+def add_coordinate_search_args(parser: argparse.ArgumentParser):
+    group = parser.add_argument_group('Post-Training Quantization Auto-Optimization (LAPQ) Arguments')
+    group.add_argument('--lapq-maxiter', default=None, type=int,
+                       help='Max number of iterations for the minimization method.')
+    group.add_argument('--lapq-maxfev', default=None, type=int,
+                       help='Max number of function evaluations for the minimization method.')
+    group.add_argument('--lapq-method', default='Powell',
+                       help='Minimization method used by scipy.optimize.minimize.')
+    group.add_argument('--lapq-basinhopping', '--lapq-bh', action='store_true', default=False,
+                       help='Use scipy.optimize.basinhopping stochastic global minimum search.')
+    group.add_argument('--lapq-basinhopping-niter', '--lapq-bh-niter', default=100,
+                       help='Number of iterations for the basinhopping algorithm.')
+    group.add_argument('--lapq-init-mode', default='NONE', choices=list(_INIT_MODES),
+                       help='The mode of quant initialization. Choices: ' + '|'.join(list(_INIT_MODES)))
+    group.add_argument('--lapq-init-method', default='Powell',
+                       help='If --lapq-init-mode was specified as L1/L2/L3, this specifies the method of '
+                            'minimization.')
+    group.add_argument('--lapq-eval-size', type=float, default=1,
+                       help='Portion of test dataset to use for evaluation function.')
+    group.add_argument('--lapq-eval-memoize-dataloader', action='store_true', default=False,
+                       help='Stores the input batch in memory to optimize performance.')
+    group.add_argument('--lapq-search-clipping', action='store_true',
+                       help='Search on clipping values instead of scale/zero_point.')
+
+
+def cmdline_args_to_dict(args):
+    """
+    Convenience function converting command line arguments obtained from add_coordinate_search_args
+    to a dictionary that can be passed as-is to ptq_coordinate_search.
+
+    Example:
+        # Assume pre-existing parser
+        add_coordinate_search_args(parser)
+        args = parser.parse_args()
+
+        # Assume quantizer, dummy_input, eval_fn, and test_fn have been set up
+        lapq_args_dict = cmdline_args_to_dict(args)
+        ptq_coordinate_search(quantizer, dummy_input, eval_fn, test_fn=test_fn, **lapq_args_dict)
+    """
+    prefix = 'lapq_'
+    len_prefix = len(prefix)
+    lapq_args = {k[len_prefix:]: v for k, v in vars(args).items() if k.startswith(prefix)}
+    lapq_args.pop('eval_size')
+    lapq_args.pop('eval_memoize_dataloader')
+    return lapq_args
+
+
+def validate_quantization_settings(quantized_model, search_clipping):
+    if search_clipping:
         return
     for n, m in quantized_model.named_modules():
         if not is_post_train_quant_wrapper(m, False):
@@ -306,55 +325,54 @@ def validate_quantization_settings(args, quantized_model):
             raise ValueError(err_msg.format('weights'))
 
 
-def ptq_coordinate_search(model, dummy_input, eval_fn, method='Powell', options=None,
-                          act_stats=None, args=None, fold_sequences=True, basinhopping=False,
-                          init_args=None, minimizer_kwargs=None,
-                          test_fn=None):
+def ptq_coordinate_search(quantizer, dummy_input, eval_fn, test_fn=None, method='Powell',
+                          maxiter=None, maxfev=None, basinhopping=False, basinhopping_niter=100,
+                          init_mode=ClipMode.NONE, init_method=None, search_clipping=False,
+                          minimizer_kwargs=None):
     """
     Searches for the optimal post-train quantization configuration (scale/zero_points)
     for a model using numerical methods, as described by scipy.optimize.minimize.
     Args:
-        model (nn.Module): model to quantize
+        quantizer (distiller.quantization.PostTrainLinearQuantizer): A configured PostTrainLinearQuantizer object
+            containing the model being quantized
         dummy_input: a sample expected input to the model
         eval_fn (callable): evaluation function for the model. Assumed it has a signature of the form
            `eval_fn(model)->float`. this is the function to be minimized by the optimization algorithm.
-        method (str or callable): minimization method as accepted by scipy.optimize.minimize.
-        options (dict or None): options for the scipy optimizer
-        act_stats (OrderedDict): dictionary of statistics per layer, including inputs and outputs.
-          for more context refer to collect_quant_stats.
-        args: arguments from command-line.
-        fold_sequences (bool): flag, indicates to fold sequences before performing the search.
+        test_fn (callable): a function to test the current performance of the model. Assumed it has a signature of
+            the form `test_fn(model)->dict`, where the returned dict contains relevant results to be logged.
+            For example: {'top-1': VAL, 'top-5': VAL, 'loss': VAL}
+        method (str or callable): Minimization method as accepted by scipy.optimize.minimize.
+        maxiter (int): Maximum number of iterations to perform during minimization
+        maxfev (int): Maximum number of total function evaluations to perform during minimization
         basinhopping (bool): flag, indicates to use basinhopping as a global-minimization method,
           will pass the `method` argument to `scipy.optimize.basinhopping`.
-        init_args (tuple): arguments for initializing the linear quantization parameters.
-          Refer to `init_linear_quant_params` for more details.
-        minimizer_kwargs (dict): the kwargs for scipy.optimize.minimize procedure.
-        test_fn (callable): a function to test the current performance of the model.
+        basinhopping_niter (int): Number of iterations to perform if basinhopping is set
+        init_mode (ClipMode or callable or str or dict): See 'init_linear_quant_params'
+        init_method (str or callable): See 'init_layer_linear_quant_params'
+        search_clipping (bool): Search on clipping values instead of directly on scale/zero-point (scale and zero-
+            point are inferred from the clipping values)
+        minimizer_kwargs (dict): Optional additional arguments for scipy.optimize.minimize
     """
-    if fold_sequences:
-        model = fold_batch_norms(model, dummy_input)
-    if args is None:
-        args = get_default_args()
-    elif isinstance(args, dict):
-        updated_args = get_default_args()
-        updated_args.__dict__.update(args)
-        args = updated_args
-    original_model = deepcopy(model)
-
-    if not act_stats and not args.qe_config_file:
+    if not isinstance(quantizer, PostTrainLinearQuantizer):
+        raise ValueError('Only PostTrainLinearQuantizer supported, but got a {}'.format(quantizer.__class__.__name__))
+    if quantizer.prepared:
+        raise ValueError('Expecting a quantizer for which prepare_model has not been called')
+
+    original_model = deepcopy(quantizer.model)
+    original_model = fold_batch_norms(original_model, dummy_input)
+
+    if not quantizer.model_activation_stats:
         msglogger.info('Collecting stats for model...')
-        model_temp = distiller.utils.make_non_parallel_copy(model)
-        act_stats = collect_quant_stats(model_temp, eval_fn)
+        model_temp = distiller.utils.make_non_parallel_copy(original_model)
+        act_stats = collect_quant_stats(model_temp, eval_fn,
+                                        inplace_runtime_check=True, disable_inplace_attrs=True,
+                                        save_dir=getattr(msglogger, 'logdir', '.'))
         del model_temp
-        if args:
-            act_stats_path = '%s_act_stats.yaml' % args.arch
-            msglogger.info('Done. Saving act stats into %s' % act_stats_path)
-            distiller.yaml_ordered_save(act_stats_path, act_stats)
-            args.qe_stats_file = act_stats_path
+        quantizer.model_activation_stats = act_stats
+        quantizer.model.quantizer_metadata['params']['model_activation_stats'] = act_stats
 
     # Preparing model and init conditions:
     msglogger.info("Initializing quantizer...")
-    quantizer = PostTrainLinearQuantizer.from_args(model, args)
 
     # Make sure weights are re-quantizable and clip-able
     quantizer.save_fp_weights = True
@@ -368,26 +386,26 @@ def ptq_coordinate_search(model, dummy_input, eval_fn, method='Powell', options=
     quantizer.prepare_model(dummy_input)
     quantizer.model.eval()
 
-    validate_quantization_settings(args, quantizer.model)
+    validate_quantization_settings(quantizer.model, search_clipping)
 
     msglogger.info("Initializing quantization parameters...")
-    init_args = init_args or (args.init_mode, args.init_mode_method)
-    init_linear_quant_params(quantizer, original_model, eval_fn, dummy_input, *init_args,
-                             search_clipping=args.search_clipping)
+    init_linear_quant_params(quantizer, original_model, eval_fn, dummy_input, init_mode, init_method,
+                             search_clipping=search_clipping)
 
     msglogger.info("Evaluating initial quantization score...")
     best_data = {
-        'score': eval_fn(model),
+        'score': eval_fn(quantizer.model),
         'qp_dict': deepcopy(quantizer.linear_quant_params)
     }
     msglogger.info("Evaluation set loss after initialization %.3f" % best_data['score'])
     if test_fn:
         msglogger.info('Evaluating on full test set...')
-        l_top1, l_top5, l_loss = test_fn(quantizer.model)
-        msglogger.info('Test: \tloss=%.3f, top1=%.3f, top5=%.3f ' % (l_loss, l_top1, l_top5))
+        results = test_fn(quantizer.model)
+        s = ', '.join(['{} = {:.3f}'.format(k, v) for k, v in results.items()])
+        msglogger.info('Test: ' + s)
 
-    init_qp_dict = OrderedDict(quantizer.named_linear_quant_params(args.search_clipping, filter=True))
-    keys, init_qp_vec = quant_params_dict2vec(init_qp_dict, args.search_clipping)
+    init_qp_dict = OrderedDict(quantizer.named_linear_quant_params(search_clipping, filter=True))
+    keys, init_qp_vec = quant_params_dict2vec(init_qp_dict, search_clipping)
 
     iter_counter = count(1)
     eval_counter = count(1)
@@ -395,7 +413,7 @@ def ptq_coordinate_search(model, dummy_input, eval_fn, method='Powell', options=
     def feed_forward_fn(qp_vec):
         # if not _check_qp_vec(keys, qp_vec, quant_mode, args.search_clipping):
         #     return 1e6
-        qp_dict = quant_params_vec2dict(keys, qp_vec, args.search_clipping)
+        qp_dict = quant_params_vec2dict(keys, qp_vec, search_clipping)
         quantizer.update_linear_quant_params(qp_dict)
         loss = eval_fn(quantizer.model)
 
@@ -411,105 +429,31 @@ def ptq_coordinate_search(model, dummy_input, eval_fn, method='Powell', options=
         msglogger.info("Iteration %d: \t Score=%.3f" % (i, score))
         if score < best_data['score']:
             best_data['score'] = score
-            best_data['qp_dict'] = quant_params_vec2dict(keys, qp_vec, args.search_clipping)
+            best_data['qp_dict'] = quant_params_vec2dict(keys, qp_vec, search_clipping)
             msglogger.info("Saving current best quantization parameters.")
         if test_fn:
             msglogger.info('Evaluating on full test set...')
-            l_top1, l_top5, l_loss = test_fn(quantizer.model)
-            msglogger.info('Test: \tloss=%.3f, top1=%.3f, top5=%.3f ' % (l_loss, l_top1, l_top5))
-
-    options = options or OrderedDict()
-    if args.maxiter is not None:
-        options['maxiter'] = args.maxiter
-    if args.maxfev is not None:
-        options['maxfev'] = args.maxfev
+            results = test_fn(quantizer.model)
+            s = ', '.join(['{} = {:.3f}'.format(k, v) for k, v in results.items()])
+            msglogger.info('Test: ' + s)
+
+    options = OrderedDict()
+    options['maxiter'] = maxiter
+    options['maxfev'] = maxfev
+
     minimizer_kwargs = minimizer_kwargs or OrderedDict()
     minimizer_kwargs.update({
         'method': method, 'options': options
     })
-    basinhopping = basinhopping or args.basinhopping
     if basinhopping:
-        msglogger.info('Using basinhopping global minimum search with "%s" local minimization method'%
-                       method)
-        res = opt.basinhopping(feed_forward_fn, init_qp_vec, args.niter, callback=callback,
+        msglogger.info('Using basinhopping global minimum search with "%s" local minimization method' % method)
+        res = opt.basinhopping(feed_forward_fn, init_qp_vec, basinhopping_niter, callback=callback,
                                minimizer_kwargs=minimizer_kwargs)
     else:
         msglogger.info('Using "%s" minimization algorithm.' % method)
         res = opt.minimize(feed_forward_fn, init_qp_vec, callback=callback, **minimizer_kwargs)
 
-    msglogger.info("Optimization done. Best configuration: %s" % best_data['qp_dict'])
-    return model, best_data['qp_dict']
-
-
-if __name__ == "__main__":
-    args = get_default_args()
-    args.epochs = float('inf')  # hack for args parsing so there's no error in epochs
-    cc = classifier.ClassifierCompressor(args, script_dir=os.path.dirname(__file__))
-
-    args = deepcopy(cc.args)
-
-    effective_test_size_bak = args.effective_test_size
-    args.effective_test_size = args.opt_val_size
-    eval_data_loader = classifier.load_data(args, load_train=False, load_val=False, load_test=True, fixed_subset=True)
-
-    args.effective_test_size = effective_test_size_bak
-    test_data_loader = classifier.load_data(args, load_train=False, load_val=False, load_test=True)
-
-    # logging
-    logging.getLogger().setLevel(logging.WARNING)
-    msglogger = logging.getLogger(__name__)
-    msglogger.setLevel(logging.INFO)
-
-    model = cc.model.eval()
-    device = next(model.parameters()).device
-
-    if args.memoize_dataloader:
-        memoized_data_loader = []
-        for images, targets in eval_data_loader:
-            batch = images.to(device), targets.to(device)
-            memoized_data_loader.append(batch)
-    else:
-        memoized_data_loader = None
-
-    def eval_fn(model):
-        if args.memoize_dataloader:
-            loss = 0
-            for images, targets in memoized_data_loader:
-                outputs = model(images)
-                loss += cc.criterion(outputs, targets).item()
-            loss = loss / len(memoized_data_loader)
-        else:
-            _, _, loss = classifier.test(eval_data_loader, model, cc.criterion, [cc.tflogger, cc.pylogger],
-                                         None, args)
-        return loss
-
-    def test_fn(model):
-        return classifier.test(test_data_loader, model, cc.criterion, [cc.tflogger, cc.pylogger], None, args)
-
-    args.device = device
-    if args.resumed_checkpoint_path:
-        args.load_model_path = args.resumed_checkpoint_path
-    if args.load_model_path:
-        msglogger.info("Loading checkpoint from %s" % args.load_model_path)
-        model = apputils.load_lean_checkpoint(model, args.load_model_path,
-                                              model_device=args.device)
-
-    if args.qe_stats_file:
-        msglogger.info("Loading stats from %s" % args.qe_stats_file)
-        with open(args.qe_stats_file, 'r') as f:
-            act_stats = distiller.yaml_ordered_load(f)
-    else:
-        act_stats = None
-
-    dummy_input = torch.rand(*model.input_shape, device=args.device)
-    model, qp_dict = ptq_coordinate_search(model, dummy_input, eval_fn, args.method,
-                                           args=args, act_stats=act_stats, test_fn=test_fn)
-
-    top1, top5, loss = test_fn(model)
-
-    msglogger.info("Arch: %s \tTest: \t top1 = %.3f \t top5 = %.3f \t loss = %.3f" %
-                   (args.arch, top1, top5, loss))
-    distiller.yaml_ordered_save('%s.quant_params_dict.yaml' % args.arch, qp_dict)
-
-    distiller.apputils.save_checkpoint(0, args.arch, model, extras={'top1': top1, 'qp_dict': qp_dict}, name=args.name,
-                                       dir=cc.logdir)
+    msglogger.info('Optimization done')
+    msglogger.info('Best score: {}'.format(best_data['score']))
+    msglogger.info('Best Configuration: {}'.format(best_data['qp_dict']))
+    return quantizer.model, best_data['qp_dict']
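As the hunks above show, `maxiter`/`maxfev` now feed straight into the scipy `options` dict, and `basinhopping` simply wraps the same local method. A toy, self-contained demo of that machinery on a quadratic objective (nothing Distiller-specific; the objective and settings are arbitrary):

```python
from itertools import count
import numpy as np
import scipy.optimize as opt

def objective(x):
    # Stand-in for feed_forward_fn: maps any parameter vector to a scalar loss.
    return float(np.sum((x - 1.5) ** 2))

iter_counter = count(1)

def callback(x, *extra):
    # scipy calls this once per iteration; basinhopping passes extra args.
    print('Iteration %d: score=%.4f' % (next(iter_counter), objective(x)))

x0 = np.zeros(3)
minimizer_kwargs = {'method': 'Powell',
                    'options': {'maxiter': 5, 'maxfev': 200}}

# Local search only (the default path)...
res = opt.minimize(objective, x0, callback=callback, **minimizer_kwargs)
# ...or global search, with basinhopping wrapping the same local method
# (third positional argument is niter, as in the patched code).
res = opt.basinhopping(objective, x0, 3, callback=callback,
                       minimizer_kwargs=minimizer_kwargs)
```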
diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index ef92314..31bfc1d 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -268,7 +268,7 @@ def linear_dequantize_with_metadata(t, inplace=False):
     return t
 
 
-def add_post_train_quant_args(argparser):
+def add_post_train_quant_args(argparser, add_lapq_args=False):
     str_to_quant_mode_map = OrderedDict([
         ('sym', LinearQuantMode.SYMMETRIC),
         ('sym_restr', LinearQuantMode.SYMMETRIC_RESTRICTED),
@@ -293,8 +293,7 @@ def add_post_train_quant_args(argparser):
     linear_quant_mode_str_optional = partial(from_dict, d=str_to_quant_mode_map, optional=True)
     clip_mode_str = partial(from_dict, d=str_to_clip_mode_map, optional=False)
 
-    group = argparser.add_argument_group('Arguments controlling quantization at evaluation time '
-                                         '("post-training quantization")')
+    group = argparser.add_argument_group('Post-Training Quantization Arguments')
     group.add_argument('--quantize-eval', '--qe', action='store_true',
                        help='Apply linear quantization to model before evaluation. Applicable only if '
                             '--evaluate is also set')
@@ -338,15 +337,21 @@ def add_post_train_quant_args(argparser):
 
     stats_group = group.add_mutually_exclusive_group()
     stats_group.add_argument('--qe-stats-file', type=str, metavar='PATH',
-                            help='Path to YAML file with pre-made calibration stats')
+                             help='Path to YAML file with pre-made calibration stats')
     stats_group.add_argument('--qe-dynamic', action='store_true', help='Apply dynamic quantization')
     stats_group.add_argument('--qe-calibration', type=distiller.utils.float_range_argparse_checker(exc_min=True),
-                            metavar='PORTION_OF_TEST_SET', default=None,
-                            help='Run the model in evaluation mode on the specified portion of the test dataset and '
-                                 'collect statistics')
+                             metavar='PORTION_OF_TEST_SET', default=None,
+                             help='Run the model in evaluation mode on the specified portion of the test dataset and '
+                                  'collect statistics')
     stats_group.add_argument('--qe-config-file', type=str, metavar='PATH',
-                             help='Path to YAML file containing configuration for PostTrainLinearQuantizer (if present, '
-                                  'all other --qe* arguments are ignored)')
+                             help='Path to YAML file containing configuration for PostTrainLinearQuantizer '
+                                  '(if present, all other --qe* arguments are ignored)')
+
+    if add_lapq_args:
+        from .ptq_coordinate_search import add_coordinate_search_args
+        group.add_argument('--qe-lapq', '--qe-coordinate-search', action='store_true',
+                           help='Optimize post-training quantization parameters using LAPQ method')
+        add_coordinate_search_args(argparser)
 
 
 class UnsatisfiedRequirements(Exception):
@@ -1597,15 +1602,6 @@ class PostTrainLinearQuantizer(Quantizer):
                 model_activation_stats = distiller.utils.yaml_ordered_load(stream)
         elif not isinstance(model_activation_stats, (dict, OrderedDict)):
             raise TypeError('model_activation_stats must either be a string, a dict / OrderedDict or None')
-        else:
-            msglogger.warning("\nWARNING:\nNo stats file passed - Dynamic quantization will be used\n"
-                              "At the moment, this mode isn't as fully featured as stats-based quantization, and "
-                              "the accuracy results obtained are likely not as representative of real-world results."
-                              "\nSpecifically:\n"
-                              " * Not all modules types are supported in this mode. Unsupported modules will remain "
-                              "in FP32.\n"
-                              " * Optimizations for quantization of layers followed by Relu/Tanh/Sigmoid are only "
-                              "supported when statistics are used.\nEND WARNING\n")
 
         mode_dict = {'activations': _enum_to_str(mode.activations),
                      'weights': _enum_to_str(mode.weights)}
         self.model.quantizer_metadata = {'type': type(self),
@@ -1928,6 +1924,16 @@ class PostTrainLinearQuantizer(Quantizer):
         msglogger.info('Per-layer quantization parameters saved to ' + save_path)
 
     def prepare_model(self, dummy_input=None):
+        if not self.model_activation_stats:
+            msglogger.warning("\nWARNING:\nNo stats file passed - Dynamic quantization will be used\n"
+                              "At the moment, this mode isn't as fully featured as stats-based quantization, and "
+                              "the accuracy results obtained are likely not as representative of real-world results."
+                              "\nSpecifically:\n"
+                              " * Not all module types are supported in this mode. Unsupported modules will remain "
+                              "in FP32.\n"
+                              " * Optimizations for quantization of layers followed by Relu/Tanh/Sigmoid are only "
+                              "supported when statistics are used.\nEND WARNING\n")
+
         self.has_bidi_distiller_lstm = any(isinstance(m, distiller.modules.DistillerLSTM) and m.bidirectional
                                            for _, m in self.model.named_modules())
         if self.has_bidi_distiller_lstm:
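With `add_lapq_args=True`, an application parser picks up both the `--qe-*` group and the new `--lapq-*` group. A minimal sketch of the wiring (the argument values passed to `parse_args` are arbitrary examples):

```python
import argparse
import distiller

parser = argparse.ArgumentParser(description='PTQ example')
# Adds the --qe-* group, plus --qe-lapq and the --lapq-* group when requested.
distiller.quantization.add_post_train_quant_args(parser, add_lapq_args=True)
args = parser.parse_args(['--quantize-eval', '--qe-lapq', '--lapq-maxiter', '2'])
print(args.qe_lapq, args.lapq_maxiter)  # -> True 2
```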
diff --git a/examples/classifier_compression/README.md b/examples/classifier_compression/README.md
index 2c72e02..f893b2d 100644
--- a/examples/classifier_compression/README.md
+++ b/examples/classifier_compression/README.md
@@ -34,6 +34,7 @@ A non-exhaustive list of the methods implemented:
 
 ### Quantization
 - [Post-training quantization](https://github.com/NervanaSystems/distiller/tree/master/examples/quantization/post_train_quant/command_line.md) based on the TensorFlow quantization scheme (originally GEMMLOWP) with additional capabilities.
+  - Optimizing post-training quantization parameters with the [LAPQ](https://arxiv.org/abs/1911.07190) method - see [example YAML](https://github.com/NervanaSystems/distiller/blob/master/examples/quantization/post_train_quant/resnet18_imagenet_post_train_lapq.yaml) file for details.
 - [Quantization-aware training](https://github.com/NervanaSystems/distiller/tree/master/examples/quantization/quant_aware_train): TensorFlow scheme, DoReFa, PACT
 
 ### Knowledge Distillation
diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index a6904f7..88911c4 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -61,6 +61,7 @@ import distiller.apputils as apputils
 import parser
 import os
 import numpy as np
+from ptq_lapq import image_classifier_ptq_lapq
 
 
 # Logger handle
@@ -69,7 +70,7 @@ msglogger = logging.getLogger()
 
 def main():
     # Parse arguments
-    args = parser.add_cmdline_args(classifier.init_classifier_compression_arg_parser()).parse_args()
+    args = parser.add_cmdline_args(classifier.init_classifier_compression_arg_parser(True)).parse_args()
     app = ClassifierCompressorSampleApp(args, script_dir=os.path.dirname(__file__))
     if app.handle_subapps():
         return
@@ -110,10 +111,13 @@ def handle_subapps(model, criterion, optimizer, compression_scheduler, pylogger,
         sensitivity_analysis(model, criterion, test_loader, pylogger, args, sensitivities)
         do_exit = True
     elif args.evaluate:
-        test_loader = load_test_data(args)
-        classifier.evaluate_model(test_loader, model, criterion, pylogger,
-                                  classifier.create_activation_stats_collectors(model, *args.activation_stats),
-                                  args, scheduler=compression_scheduler)
+        if args.quantize_eval and args.qe_lapq:
+            image_classifier_ptq_lapq(model, criterion, pylogger, args)
+        else:
+            test_loader = load_test_data(args)
+            classifier.evaluate_model(test_loader, model, criterion, pylogger,
+                                      classifier.create_activation_stats_collectors(model, *args.activation_stats),
+                                      args, scheduler=compression_scheduler)
         do_exit = True
     elif args.thinnify:
         assert args.resumed_checkpoint_path is not None, \
diff --git a/examples/classifier_compression/parser.py b/examples/classifier_compression/parser.py
index 0385d0e..5697787 100755
--- a/examples/classifier_compression/parser.py
+++ b/examples/classifier_compression/parser.py
@@ -16,7 +16,7 @@
 
 import argparse
 import distiller
-import distiller.quantization
+import distiller.pruning
 import distiller.models as models
 
 
diff --git a/examples/classifier_compression/ptq_lapq.py b/examples/classifier_compression/ptq_lapq.py
new file mode 100644
index 0000000..fa79305
--- /dev/null
+++ b/examples/classifier_compression/ptq_lapq.py
@@ -0,0 +1,99 @@
+#
+# Copyright (c) 2020 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import torch
+from copy import deepcopy
+import logging
+from collections import OrderedDict
+
+import distiller
+import distiller.apputils as apputils
+import distiller.apputils.image_classifier as classifier
+import distiller.quantization.ptq_coordinate_search as lapq
+
+
+msglogger = logging.getLogger()
+
+
+def image_classifier_ptq_lapq(model, criterion, loggers, args):
+    args = deepcopy(args)
+
+    effective_test_size_bak = args.effective_test_size
+    args.effective_test_size = args.lapq_eval_size
+    eval_data_loader = classifier.load_data(args, load_train=False, load_val=False, load_test=True, fixed_subset=True)
+
+    args.effective_test_size = effective_test_size_bak
+    test_data_loader = classifier.load_data(args, load_train=False, load_val=False, load_test=True)
+
+    model = model.eval()
+    device = next(model.parameters()).device
+
+    if args.lapq_eval_memoize_dataloader:
+        images_batches = []
+        targets_batches = []
+        for images, targets in eval_data_loader:
+            images_batches.append(images.to(device))
+            targets_batches.append(targets.to(device))
+        memoized_data_loader = [(torch.cat(images_batches), torch.cat(targets_batches))]
+    else:
+        memoized_data_loader = None
+
+    def eval_fn(model):
+        if memoized_data_loader:
+            loss = 0
+            for images, targets in memoized_data_loader:
+                outputs = model(images)
+                loss += criterion(outputs, targets).item()
+            loss = loss / len(memoized_data_loader)
+        else:
+            _, _, loss = classifier.test(eval_data_loader, model, criterion, loggers, None, args)
+        return loss
+
+    def test_fn(model):
+        top1, top5, loss = classifier.test(test_data_loader, model, criterion, loggers, None, args)
+        return OrderedDict([('top-1', top1), ('top-5', top5), ('loss', loss)])
+
+    args.device = device
+    if args.resumed_checkpoint_path:
+        args.load_model_path = args.resumed_checkpoint_path
+    if args.load_model_path:
+        msglogger.info("Loading checkpoint from %s" % args.load_model_path)
+        model = apputils.load_lean_checkpoint(model, args.load_model_path,
+                                              model_device=args.device)
+
+    quantizer = distiller.quantization.PostTrainLinearQuantizer.from_args(model, args)
+
+    dummy_input = torch.rand(*model.input_shape, device=args.device)
+    model, qp_dict = lapq.ptq_coordinate_search(quantizer, dummy_input, eval_fn, test_fn=test_fn,
+                                                **lapq.cmdline_args_to_dict(args))
+
+    results = test_fn(quantizer.model)
+    msglogger.info("Arch: %s \tTest: \t top1 = %.3f \t top5 = %.3f \t loss = %.3f" %
+                   (args.arch, results['top-1'], results['top-5'], results['loss']))
+    distiller.yaml_ordered_save('%s.quant_params_dict.yaml' % args.arch, qp_dict)
+
+    distiller.apputils.save_checkpoint(0, args.arch, model,
+                                       extras={'top1': results['top-1'], 'qp_dict': qp_dict}, name=args.name,
+                                       dir=msglogger.logdir)
+
+
+if __name__ == "__main__":
+    parser = classifier.init_classifier_compression_arg_parser(include_ptq_lapq_args=True)
+    args = parser.parse_args()
+    args.epochs = float('inf')  # hack for args parsing so there's no error in epochs
+    cc = classifier.ClassifierCompressor(args, script_dir=os.path.dirname(__file__))
+    image_classifier_ptq_lapq(cc.model, cc.criterion, [cc.pylogger, cc.tflogger], cc.args)
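A detail of the new `test_fn` contract used above: it returns a metrics dict rather than the old `(top1, top5, loss)` tuple, and `ptq_coordinate_search` logs whatever keys it finds. A tiny illustration of that contract (the metric values are made up):

```python
from collections import OrderedDict

def test_fn(model):
    # Contract per the docstring: test_fn(model) -> dict of metrics to log.
    return OrderedDict([('top-1', 69.758), ('top-5', 89.078), ('loss', 1.251)])

results = test_fn(None)
print('Test: ' + ', '.join('{} = {:.3f}'.format(k, v) for k, v in results.items()))
# -> Test: top-1 = 69.758, top-5 = 89.078, loss = 1.251
```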
diff --git a/examples/quantization/post_train_quant/command_line.md b/examples/quantization/post_train_quant/command_line.md
index ce2c948..2c112eb 100644
--- a/examples/quantization/post_train_quant/command_line.md
+++ b/examples/quantization/post_train_quant/command_line.md
@@ -30,8 +30,9 @@ Post-training quantization can either be configured straight from the command-li
 | `--qe-stats-file` | N/A | Use stats file for static quantization of activations. See details below | None |
 | `--qe-dynamic` | N/A | Perform dynamic quantization. See details below | None |
 | `--qe-config-file` | N/A | Path to YAML config file. See section above. (ignores all other --qe* arguments) | None |
-| `--qe-convert-pytorch` | `--qept` | Convert the model to PyTorch native post-train quantization modules | Off |
+| `--qe-convert-pytorch` | `--qept` | Convert the model to PyTorch native post-train quantization modules. See [tutorial](https://github.com/NervanaSystems/distiller/blob/master/jupyter/post_train_quant_convert_pytorch.ipynb) for more details | Off |
 | `--qe-pytorch-backend` | N/A | When --qe-convert-pytorch is set, specifies the PyTorch quantization backend to use. Choices: "fbgemm", "qnnpack" | Off |
+| `--qe-lapq` | N/A | Optimize post-training quantization parameters using the [LAPQ](https://arxiv.org/abs/1911.07190) method. Beyond the scope of this document. See [example YAML](https://github.com/NervanaSystems/distiller/blob/master/examples/quantization/post_train_quant/resnet18_imagenet_post_train_lapq.yaml) file for details | Off |
 
 ### Notes
 
@@ -40,10 +41,6 @@ Post-training quantization can either be configured straight from the command-li
   * `--quantize-eval` is also set, in which case an FP32 model is first quantized using Distiller's post-training
     quantization flow, and then converted to a PyTorch native quantization model.
   * `--quantize-eval` is not set, but a previously post-train quantized model is loaded via `--resume`. In this case,
     the loaded model is converted to PyTorch native quantization.
 
-### Conversion to PyTorch Built-in Quantization Model
-
-PyTorch released built-in support for quantization in version 1.3. Currently Distiller's quantization functionality is
-still completely separate from PyTorch's. We provide the ability to take a model which was post-train quantized with
-Distiller, and is comprised of `RangeLinearQuantWrapper`
-
 ## "Net-Aware" Quantization
 
 The term "net-aware" quantization, coined in [this](https://arxiv.org/abs/1811.09886) paper from Facebook (section 3.2.2),
 means we can achieve better quantization by considering sequences of operations instead of just quantizing each operation
 independently. This isn't exactly layer fusion - in Distiller we modify activation stats prior to setting quantization
 parameters, in order to make sure that when a module is followed by certain activation functions, only the relevant
 ranges are quantized.
 We do this for:
diff --git a/examples/quantization/post_train_quant/resnet18_imagenet_post_train_lapq.yaml b/examples/quantization/post_train_quant/resnet18_imagenet_post_train_lapq.yaml
index 87b42f2..b920326 100644
--- a/examples/quantization/post_train_quant/resnet18_imagenet_post_train_lapq.yaml
+++ b/examples/quantization/post_train_quant/resnet18_imagenet_post_train_lapq.yaml
@@ -21,7 +21,7 @@ quantizers:
     mode:
       activations: ASYMMETRIC_UNSIGNED
       weights: SYMMETRIC
-    model_activation_stats: ../../examples/quantization/post_train_quant/stats/resnet18_quant_stats.yaml
+    # model_activation_stats: ../quantization/post_train_quant/stats/resnet18_quant_stats.yaml
     per_channel_wts: False
     inputs_quant_auto_fallback: True
 
@@ -47,12 +47,13 @@ quantizers:
 
 # Example invocations:
 # * Preliminaries:
-#   cd <distiller_root>/distiller/quantization
-#   CONFIG_FILE="../../examples/quantization/post_train_quant/resnet18_imagenet_post_train_lapq.yaml"
+#   cd <distiller_root>/examples/classifier_compression
+#   CONFIG_FILE="../quantization/post_train_quant/resnet18_imagenet_post_train_lapq.yaml"
+#   IMAGENET_PATH=<path_to_imagenet>
 #
 # * Using L3 initialization:
 #   Command:
-#     python ptq_coordinate_search.py -a resnet18 --pretrained <path_to_imagenet> --opt-val-size 0.01 --opt-maxiter 2 --qe-config-file $CONFIG_FILE -b 500 --opt-init-mode L3 --opt-init-method powell --opt-eval-memoize-dataloader --det --opt-search-clipping
+#     python compress_classifier.py --eval --qe --qe-lapq -a resnet18 --pretrained $IMAGENET_PATH --lapq-eval-size 0.01 --lapq-maxiter 2 --qe-config-file $CONFIG_FILE -b 500 --lapq-init-mode L3 --lapq-init-method powell --lapq-eval-memoize-dataloader --det --lapq-search-clipping
 #
 #   Excerpts from output:
 #   ...
@@ -77,7 +78,7 @@ quantizers:
 #
 # * Using LAPLACE initialization:
 #   Command:
-#     python ptq_coordinate_search.py -a resnet18 --pretrained <path_to_imagenet> --opt-val-size 0.01 --opt-maxiter 2 --qe-config-file $CONFIG_FILE -b 500 --opt-init-mode LAPLACE --opt-init-method powell --opt-eval-memoize-dataloader --det --opt-search-clipping
+#     python compress_classifier.py --eval --qe --qe-lapq -a resnet18 --pretrained $IMAGENET_PATH --lapq-eval-size 0.01 --lapq-maxiter 2 --qe-config-file $CONFIG_FILE -b 500 --lapq-init-mode LAPLACE --lapq-init-method powell --lapq-eval-memoize-dataloader --det --lapq-search-clipping
 #
 #   Excerpts from output:
 #   ...
-- 
GitLab