From 0b493fd30af03550d6e1e33c9d83c0f3f83b90c5 Mon Sep 17 00:00:00 2001 From: Lev Zlotnik <46742999+levzlotnik@users.noreply.github.com> Date: Sun, 2 Feb 2020 12:22:10 +0200 Subject: [PATCH] Loss Aware Post Train Quantization search (#432) "Loss Aware Post-Training Quantization" (Nahshan et al., 2019) Paper: https://arxiv.org/abs/1911.07190 Reference implementation: https://github.com/ynahshan/nn-quantization-pytorch/tree/master/lapq Proper documentation is still TODO, for now see the example YAML file at 'examples/quantization/post_train_quant/resnet18_imagenet_post_train_lapq.yaml' * Implemented in distiller/quantization/ptq_coordinate_search.py * At the moment that file both the model-independent algorithm implementation and image-classification specific sample script. Still TODO: Refactor that * Post train quantization changes (range_linear): * Added getters/setters for quantization parameters (scale/zero_point) and clipping values * Add option to save backup of FP32 weights to allow re-quantization after quantizer was created. * Add option to clip weights in addition to activations * Fix fusions to not occur only when activations aren't quantized * RangeLinearFakeQuantWrapper: * Make inputs quantization optional * In case of ReLU + ACIQ, clip according to input stats * Data loaders: * Add option to not load train set at all from disk (to speed up loading time in post-training runs) * Modified "image_classifier.py" accordingly --- distiller/apputils/data_loaders.py | 134 +++-- distiller/apputils/image_classifier.py | 13 +- .../quantization/ptq_coordinate_search.py | 515 ++++++++++++++++++ distiller/quantization/quantizer.py | 5 + distiller/quantization/range_linear.py | 500 ++++++++++++++--- distiller/utils.py | 10 +- .../resnet18_imagenet_post_train_lapq.yaml | 101 ++++ tests/test_post_train_quant.py | 55 +- 8 files changed, 1171 insertions(+), 162 deletions(-) create mode 100644 distiller/quantization/ptq_coordinate_search.py create mode 100644 examples/quantization/post_train_quant/resnet18_imagenet_post_train_lapq.yaml diff --git a/distiller/apputils/data_loaders.py b/distiller/apputils/data_loaders.py index 2394aa9..ae18998 100755 --- a/distiller/apputils/data_loaders.py +++ b/distiller/apputils/data_loaders.py @@ -66,7 +66,7 @@ def __dataset_factory(dataset): def load_data(dataset, data_dir, batch_size, workers, validation_split=0.1, deterministic=False, effective_train_size=1., effective_valid_size=1., effective_test_size=1., - fixed_subset=False, sequential=False): + fixed_subset=False, sequential=False, test_only=False): """Load a dataset. 
Args: @@ -94,29 +94,34 @@ def load_data(dataset, data_dir, batch_size, workers, validation_split=0.1, dete effective_valid_size=effective_valid_size, effective_test_size=effective_test_size, fixed_subset=fixed_subset, - sequential=sequential) + sequential=sequential, + test_only=test_only) -def mnist_get_datasets(data_dir): +def mnist_get_datasets(data_dir, load_train=True, load_test=True): """Load the MNIST dataset.""" - train_transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) - ]) - train_dataset = datasets.MNIST(root=data_dir, train=True, - download=True, transform=train_transform) - - test_transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) - ]) - test_dataset = datasets.MNIST(root=data_dir, train=False, - transform=test_transform) + train_dataset = None + if load_train: + train_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ]) + train_dataset = datasets.MNIST(root=data_dir, train=True, + download=True, transform=train_transform) + + test_dataset = None + if load_test: + test_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ]) + test_dataset = datasets.MNIST(root=data_dir, train=False, + transform=test_transform) return train_dataset, test_dataset -def cifar10_get_datasets(data_dir): +def cifar10_get_datasets(data_dir, load_train=True, load_test=True): """Load the CIFAR10 dataset. The original training dataset is split into training and validation sets (code is @@ -134,28 +139,32 @@ def cifar10_get_datasets(data_dir): [1] C.-Y. Lee, S. Xie, P. Gallagher, Z. Zhang, and Z. Tu. Deeply Supervised Nets. arXiv:1409.5185, 2014 """ - train_transform = transforms.Compose([ - transforms.RandomCrop(32, padding=4), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) - ]) - - train_dataset = datasets.CIFAR10(root=data_dir, train=True, - download=True, transform=train_transform) - - test_transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) - ]) - - test_dataset = datasets.CIFAR10(root=data_dir, train=False, - download=True, transform=test_transform) + train_dataset = None + if load_train: + train_transform = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + ]) + + train_dataset = datasets.CIFAR10(root=data_dir, train=True, + download=True, transform=train_transform) + + test_dataset = None + if load_test: + test_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + ]) + + test_dataset = datasets.CIFAR10(root=data_dir, train=False, + download=True, transform=test_transform) return train_dataset, test_dataset -def imagenet_get_datasets(data_dir): +def imagenet_get_datasets(data_dir, load_train=True, load_test=True): """ Load the ImageNet dataset. 
""" @@ -164,23 +173,27 @@ def imagenet_get_datasets(data_dir): normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) - train_transform = transforms.Compose([ - transforms.RandomResizedCrop(224), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - normalize, - ]) + train_dataset = None + if load_train: + train_transform = transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ]) - train_dataset = datasets.ImageFolder(train_dir, train_transform) + train_dataset = datasets.ImageFolder(train_dir, train_transform) - test_transform = transforms.Compose([ - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - normalize, - ]) + test_dataset = None + if load_test: + test_transform = transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ]) - test_dataset = datasets.ImageFolder(test_dir, test_transform) + test_dataset = datasets.ImageFolder(test_dir, test_transform) return train_dataset, test_dataset @@ -263,14 +276,25 @@ def _get_sampler(data_source, effective_size, fixed_subset=False, sequential=Fal def get_data_loaders(datasets_fn, data_dir, batch_size, num_workers, validation_split=0.1, deterministic=False, effective_train_size=1., effective_valid_size=1., effective_test_size=1., fixed_subset=False, - sequential=False): - train_dataset, test_dataset = datasets_fn(data_dir) + sequential=False, test_only=False): + train_dataset, test_dataset = datasets_fn(data_dir, load_train=not test_only, load_test=True) worker_init_fn = None if deterministic: distiller.set_deterministic() worker_init_fn = __deterministic_worker_init_fn + test_indices = list(range(len(test_dataset))) + test_sampler = _get_sampler(test_indices, effective_test_size, fixed_subset, sequential) + test_loader = torch.utils.data.DataLoader(test_dataset, + batch_size=batch_size, sampler=test_sampler, + num_workers=num_workers, pin_memory=True) + + input_shape = __image_size(test_dataset) + + if test_only: + return None, None, test_loader, input_shape + num_train = len(train_dataset) indices = list(range(num_train)) @@ -296,13 +320,5 @@ def get_data_loaders(datasets_fn, data_dir, batch_size, num_workers, validation_ num_workers=num_workers, pin_memory=True, worker_init_fn=worker_init_fn) - test_indices = list(range(len(test_dataset))) - test_sampler = _get_sampler(test_indices, effective_test_size, fixed_subset, sequential) - test_loader = torch.utils.data.DataLoader(test_dataset, - batch_size=batch_size, sampler=test_sampler, - num_workers=num_workers, pin_memory=True) - - input_shape = __image_size(train_dataset) - # If validation split was 0 we use the test set as the validation set return train_loader, valid_loader or test_loader, test_loader, input_shape diff --git a/distiller/apputils/image_classifier.py b/distiller/apputils/image_classifier.py index 172bfeb..3dd6339 100755 --- a/distiller/apputils/image_classifier.py +++ b/distiller/apputils/image_classifier.py @@ -470,13 +470,18 @@ def save_collectors_data(collectors, directory): def load_data(args, fixed_subset=False, sequential=False, load_train=True, load_val=True, load_test=True): - train_loader, val_loader, test_loader, _ = apputils.load_data(args.dataset, + test_only = not load_train and not load_val + + train_loader, val_loader, test_loader, _ = apputils.load_data(args.dataset, os.path.expanduser(args.data), args.batch_size, args.workers, args.validation_split, 
args.deterministic, args.effective_train_size, args.effective_valid_size, args.effective_test_size, - fixed_subset, sequential) - msglogger.info('Dataset sizes:\n\ttraining=%d\n\tvalidation=%d\n\ttest=%d', - len(train_loader.sampler), len(val_loader.sampler), len(test_loader.sampler)) + fixed_subset, sequential, test_only) + if test_only: + msglogger.info('Dataset sizes:\n\ttest=%d', len(test_loader.sampler)) + else: + msglogger.info('Dataset sizes:\n\ttraining=%d\n\tvalidation=%d\n\ttest=%d', + len(train_loader.sampler), len(val_loader.sampler), len(test_loader.sampler)) loaders = (train_loader, val_loader, test_loader) flags = (load_train, load_val, load_test) diff --git a/distiller/quantization/ptq_coordinate_search.py b/distiller/quantization/ptq_coordinate_search.py new file mode 100644 index 0000000..cd2f4df --- /dev/null +++ b/distiller/quantization/ptq_coordinate_search.py @@ -0,0 +1,515 @@ +# +# Copyright (c) 2019 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# +# Implementation of "Loss Aware Post-Training Quantization" (Nahshan et al., 2019) +# +# Paper: https://arxiv.org/abs/1911.07190 +# Reference implementation: https://github.com/ynahshan/nn-quantization-pytorch/tree/master/lapq +# + +import torch +import torch.nn as nn +from distiller.quantization.range_linear import PostTrainLinearQuantizer, ClipMode, \ + RangeLinearQuantWrapper, RangeLinearEmbeddingWrapper, RangeLinearQuantParamLayerWrapper, \ + is_post_train_quant_wrapper, LinearQuantMode +from distiller.quantization import is_linear_quant_mode_asymmetric, is_linear_quant_mode_symmetric +from functools import partial +from distiller.summary_graph import SummaryGraph +from distiller.model_transforms import fold_batch_norms +import distiller.modules +from distiller.data_loggers import collect_quant_stats +from collections import OrderedDict +from itertools import count +import logging +from copy import deepcopy +import distiller.apputils.image_classifier as classifier +import os +import distiller.apputils as apputils +import scipy.optimize as opt +import numpy as np + + +def quant_params_dict2vec(p_dict, search_clipping=False): + """ + Convert quantization params dictionary returned by post-train quantizer to a numpy array that can be used + with scipy.opt.minimize + """ + keys = [] + vals = [] + for k, v in p_dict.items(): + if search_clipping and isinstance(v, tuple): + # When both min and amx values are optimized, we need to concatenate them in the array + # We create dual matching keys so it's easy to convert back to a dict + keys += [k + '_min', k + '_max'] + vals += [v[0].item(), v[1].item()] + else: + keys.append(k) + vals.append(v.item()) + + return keys, np.array(vals) + + +def quant_params_vec2dict(keys, vals, search_clipping=False): + """ + Convert the vector(s) created by quant_params_dict2vec to a dictionary of quantization parameters that + the post-training quantizer API can digest + """ + res = OrderedDict() + for idx, k in enumerate(keys): + if search_clipping and k.endswith('_min'): 
+ res[k[:-4]] = sorted((vals[idx], vals[idx + 1])) + elif search_clipping and k.endswith('_max'): + continue + else: + res[k] = abs(vals[idx]) + return res + + +def lp_loss(x: torch.Tensor, y: torch.Tensor, p): + tmp = (x - y).abs_().pow_(p) + loss = (torch.sum(tmp) / x.numel()).item() + return loss + + +def _check_qp_vec(keys, qp_vec, quant_mode=LinearQuantMode.SYMMETRIC, search_clipping=False): + if is_linear_quant_mode_symmetric(quant_mode): + return all(qp_vec > 0) + if not search_clipping: + idxs_scales = np.array(['scale' in key for key in keys]) + qp_vec_scales = qp_vec[idxs_scales] + return all(qp_vec_scales > 0) + + +l1_loss = partial(lp_loss, p=1) +l2_loss = partial(lp_loss, p=2) +l3_loss = partial(lp_loss, p=3) + + +_INIT_MODES = { + 'NONE': ClipMode.NONE, 'AVG': ClipMode.AVG, 'LAPLACE': ClipMode.LAPLACE, 'GAUSS': ClipMode.GAUSS, + 'L1': l1_loss, 'L2': l2_loss, 'L3': l3_loss +} + + +def _init_mode_from_str(init_mode_str): + if init_mode_str not in _INIT_MODES: + raise ValueError('Unsupported init mode \'%s\'. ' + 'The supported init modes are: %s.' % (init_mode_str, _INIT_MODES)) + return _INIT_MODES[init_mode_str] + + +def optimize_for_layer(layer, quantized_layer, loss_fn, input, method=None, search_clipping=False): + """ + Searches for optimal linear quantization parameters (scale, zero_point) for a layer + with respect to the loss function. Assumes loss_fn is of the signature `loss_fn(y, y_q)->float` + + We perform the initialization a bit differently compared to the paper/reference implementation: + * In the reference: + * Weights and activations are initialized based on quantization loss of their respective tensors. + * Activations are initialized "online", meaning the input to the layer N being initialized is the + output of the already quantized layer N-1. + * In this implementation: + * For a given layer, we initialize both activations and weights together (as applicable) based on the + LP loss between the quantized layer output and the FP32 layer output. + * But, we don't do "online" initialization. That is, each layer is initialized independently from the + quantization parameters obtained for earlier layers. + + Args: + layer (nn.Module): the original, pre-quantized, layer. + quantized_layer (RangeLinearQuantWrapper or RangeLinearEmbeddingWrapper): the post-quantized layer. + loss_fn (callable): the loss function to optimize with respect to it. + method (str or callable): the method of optimization, as will be used by scipy.optimize.minimize. 
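
The two helper functions above are the glue between the quantizer's parameter dictionary and the flat vector that scipy.optimize.minimize operates on, and the Lp losses are the per-layer objectives used for initialization. A minimal sketch of the round trip, assuming the new file is importable as distiller.quantization.ptq_coordinate_search; the layer name and numeric values are illustrative only, not taken from the patch:

import torch
import numpy as np
from collections import OrderedDict
from distiller.quantization.ptq_coordinate_search import (
    quant_params_dict2vec, quant_params_vec2dict, l2_loss)

# A hypothetical parameter dict, shaped like what named_linear_quant_params() yields
qp_dict = OrderedDict([
    ('conv1.output_scale', torch.tensor(0.05)),
    ('conv1.output_zero_point', torch.tensor(3.)),
])

keys, qp_vec = quant_params_dict2vec(qp_dict)   # flat numpy vector for the optimizer
assert isinstance(qp_vec, np.ndarray) and len(qp_vec) == 2

# ... the optimizer perturbs qp_vec ...
new_dict = quant_params_vec2dict(keys, qp_vec)  # back to a dict the quantizer can digest
assert list(new_dict.keys()) == keys

# The Lp losses compare a layer's FP32 output with its quantized output
print(l2_loss(torch.randn(4, 8), torch.randn(4, 8)))
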
+ search_clipping (bool): if set, optimize clipping values, otherwise optimize scale factor + Returns: + quantized_layer after optimization + """ + params_gen = quantized_layer.named_linear_quant_params(filter=True) if not search_clipping \ + else quantized_layer.named_clipping(filter=True) + init_qp_dict = OrderedDict(params_gen) + + keys, init_qp_vec = quant_params_dict2vec(init_qp_dict, search_clipping) + + def feed_forward_fn(qp_vec): + qp_dict = quant_params_vec2dict(keys, qp_vec, search_clipping) + quantized_layer.update_linear_quant_params(qp_dict) + # Using cloned input, required if the layer is inplace + y = layer(input.clone().detach()) + if getattr(quantized_layer, 'clip_half_range', False): + torch.relu_(y) + q_y = quantized_layer(input.clone().detach()) + loss = loss_fn(y, q_y) + return loss + + result = opt.minimize(feed_forward_fn, init_qp_vec, method=method) # type: opt.OptimizeResult + return quantized_layer + + +def get_input_for_layer(model, layer_name, eval_fn): + layer = dict(model.named_modules())[layer_name] # type: nn.Module + layer_inputs = [] + + def hook_layer_input(module, input): + layer_inputs.append(input[0].clone().detach()) + + handle = layer.register_forward_pre_hook(hook_layer_input) + eval_fn(model) + assert len(layer_inputs) == 1 + handle.remove() + return layer_inputs[0] + + +def init_layer_linear_quant_params(quantizer, original_model, layer_name, init_mode, + init_mode_method='Powell', eval_fn=None, search_clipping=False): + """ + Initializes a layer's linear quant parameters. + This is done to set the scipy.optimize.minimize initial guess. + Args: + quantizer (PostTrainLinearQuantizer): the quantizer, **after** calling prepare model. + original_model (nn.Module): the original, pre-quantized, model. + layer_name (str): the name of the layer. + init_mode (ClipMode or callable or str): the initialization mode. + If ClipMode, the initialization will be according to the respective ClipMode. + If callable - init_mode will be treated as a loss function between the activations pre and post-quantization, + and the initialization process will attempt to find the minimum of that loss function. + E.g. if l1_loss has been passed, the initialization vector will be + scale, zero_point = argmin_{s, zp} (l1_loss(layer(input), q_layer(input; s, zp))) + If str - the mode will be chosen from a list of options. The options are: + [NONE, AVG, LAPLACE, GAUSS, L1, L2 ,L3]. + Defaults to ClipMode.NONE + init_mode_method (str or callable): applicable only in the case of init_mode = 'L1/2/3' or callable. + chooses the minimization method for finding the local argmin_{s, zp}. + Defaults to 'Powell' + eval_fn: evaluation function for the model. Assumed it has a signature of the form + `eval_fn(model)->float`. this is the function to be minimized by the optimization algorithm. + applicable only in the case of init_mode = 'L1/2/3' or callable. 
+ search_clipping (bool): if set, optimize clipping values, otherwise optimize scale factor + """ + denorm_layer_name = distiller.denormalize_module_name(quantizer.model, layer_name) + msglogger.info(denorm_layer_name) + if isinstance(init_mode, str): + init_mode = _init_mode_from_str(init_mode) + layer = dict(original_model.named_modules())[layer_name] + local_args, local_kwargs = quantizer.modules_processed_args[denorm_layer_name] + if isinstance(init_mode, ClipMode): + local_kwargs['clip_acts'] = init_mode + replace_fn = quantizer.replacement_factory.get(type(layer), quantizer.default_repalcement_fn) + quantized_layer = replace_fn(deepcopy(layer), *local_args, **local_kwargs).eval() + if not is_post_train_quant_wrapper(quantized_layer, False): + # the module wasn't quantized, nothing to do here + return + + if callable(init_mode): + input_for_layer = get_input_for_layer(original_model, layer_name, eval_fn) + quantized_layer = optimize_for_layer(layer, quantized_layer, init_mode, input_for_layer, init_mode_method, + search_clipping=search_clipping) + + distiller.model_setattr(quantizer.model, denorm_layer_name, quantized_layer) + quantizer.model.eval() + + +def init_linear_quant_params(quantizer, original_model, eval_fn, dummy_input, init_mode, + init_mode_method=None, search_clipping=False): + """ + Initializes all linear quantization parameters of the model. + Args: + quantizer (PostTrainLinearQuantizer): the quantizer, **after** calling prepare model. + original_model (nn.Module): the original, pre-quantized, model. + init_mode (ClipMode or callable or str or dict): See `init_layer_linear_qaunt_params`. + if init_mode is dict - init_mode is configuration for the different layers, + i.e. init_mode = Dict[layer_name:str, init_mode_layer: ClipMode or callable or str]. + eval_fn: evaluation function for the model. Assumed it has a signature of the form + `eval_fn(model)->float`. this is the function to be minimized by the optimization algorithm. + Note - unlike in `init_layer_linear_quant_params`, this argument is required here. + dummy_input: dummy sample input to the model + init_mode_method: See `init_layer_linear_qaunt_params`. 
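
Taken together, the per-layer initializer above and the whole-model initializer below let the starting point for the search be driven either by a ClipMode or by one of the Lp losses. A hedged sketch of how the whole-model initializer might be invoked once prepare_model() has run; quantizer, original_model, eval_fn and dummy_input are placeholders supplied by the caller, not defined by the patch:

from distiller.quantization.range_linear import ClipMode
from distiller.quantization.ptq_coordinate_search import init_linear_quant_params, l2_loss

def initialize_quant_params(quantizer, original_model, eval_fn, dummy_input):
    # quantizer: PostTrainLinearQuantizer after prepare_model()
    # original_model: an FP32 copy of the model taken before quantization
    # eval_fn(model) -> float evaluates the quantized model on a small evaluation set
    init_linear_quant_params(quantizer, original_model, eval_fn, dummy_input,
                             init_mode=l2_loss,           # or ClipMode.LAPLACE, 'L3', ...
                             init_mode_method='Powell')   # scipy method for the per-layer argmin

init_mode may also be a per-layer dict, e.g. {'conv1': 'L2', 'fc': ClipMode.NONE}, as described in the docstring.
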
+ search_clipping (bool): if set, optimize clipping values, otherwise optimize scale factor + """ + original_model = distiller.make_non_parallel_copy(original_model) + layers_topological_order = SummaryGraph(original_model, dummy_input).layers_topological_order() + q_named_modules = OrderedDict(quantizer.model.named_modules()) + for module_name in layers_topological_order: + # check to see if it was quantized: + q_module = q_named_modules[distiller.denormalize_module_name(quantizer.model, module_name)] + if not is_post_train_quant_wrapper(q_module, False): + continue + module_init_mode = init_mode[module_name] if isinstance(init_mode, dict) else init_mode + msglogger.debug('Initializing layer \'%s\' using %s mode' % (module_name, module_init_mode)) + init_layer_linear_quant_params(quantizer, original_model, module_name, module_init_mode, + init_mode_method=init_mode_method, + eval_fn=eval_fn, + search_clipping=search_clipping) + del original_model + + quantizer._post_prepare_model() + quantizer.model.eval() + + +def get_default_args(): + parser = classifier.init_classifier_compression_arg_parser() + parser.add_argument('--opt-maxiter', dest='maxiter', default=None, type=int, + help='Max iteration for minimization method.') + parser.add_argument('--opt-maxfev', dest='maxfev', default=None, type=int, + help='Max iteration for minimization method.') + parser.add_argument('--opt-method', dest='method', default='Powell', + help='Minimization method used by scip.optimize.minimize.') + parser.add_argument('--opt-bh', dest='basinhopping', action='store_true', default=False, + help='Use scipy.optimize.basinhopping stochastic global minimum search.') + parser.add_argument('--opt-bh-niter', dest='niter', default=100, + help='Number of iterations for the basinhopping algorithm.') + parser.add_argument('--opt-init-mode', dest='init_mode', default='NONE', + choices=list(_INIT_MODES), + help='The mode of quant initalization. Choices: ' + '|'.join(list(_INIT_MODES))) + parser.add_argument('--opt-init-method', dest='init_mode_method', + help='If --opt-init-mode was specified as L1/L2/L3, this specifies the method of ' + 'minimization.') + parser.add_argument('--opt-val-size', type=float, default=1, + help='Use portion of the test size.') + parser.add_argument('--opt-eval-memoize-dataloader', dest='memoize_dataloader', action='store_true', default=False, + help='Stores the input batch in memory to optimize performance.') + parser.add_argument('--base-score', type=float, default=None) + parser.add_argument('--opt-search-clipping', dest='search_clipping', action='store_true', + help='Search on clipping values instead of scale/zero_point.') + args = parser.parse_args() + return args + + +def validate_quantization_settings(args, quantized_model): + if args.search_clipping: + return + for n, m in quantized_model.named_modules(): + if not is_post_train_quant_wrapper(m, False): + continue + + err_msg = 'Detected asymmetric quantization of {}. ' \ + 'Switch to symmetric quantization or enable search_clipping.' 
+ if not isinstance(m, RangeLinearEmbeddingWrapper): + if m.output_quant_settings.num_bits and \ + is_linear_quant_mode_asymmetric(m.mode.activations) and \ + not m.clip_half_range: + raise ValueError(err_msg.format('activations without fused ReLU')) + if isinstance(m, (RangeLinearEmbeddingWrapper, RangeLinearQuantParamLayerWrapper)): + if is_linear_quant_mode_asymmetric(m.mode.weights): + raise ValueError(err_msg.format('weights')) + + +def ptq_coordinate_search(model, dummy_input, eval_fn, method='Powell', options=None, + act_stats=None, args=None, fold_sequences=True, basinhopping=False, + init_args=None, minimizer_kwargs=None, + test_fn=None): + """ + Searches for the optimal post-train quantization configuration (scale/zero_points) + for a model using numerical methods, as described by scipy.optimize.minimize. + Args: + model (nn.Module): model to quantize + dummy_input: an sample expected input to the model + eval_fn (callable): evaluation function for the model. Assumed it has a signature of the form + `eval_fn(model)->float`. this is the function to be minimized by the optimization algorithm. + method (str or callable): minimization method as accepted by scipy.optimize.minimize. + options (dict or None): options for the scipy optimizer + act_stats (OrderedDict): dictionary of statistics per layer, including inputs and outputs. + for more context refer to collect_quant_stats. + args: arguments from command-line. + fold_sequences (bool): flag, indicates to fold sequences before performing the search. + basinhopping (bool): flag, indicates to use basinhopping as a global-minimization method, + will pass the `method` argument to `scipy.optimize.basinhopping`. + init_args (tuple): arguments for initializing the linear quantization parameters. + Refer to `init_linear_quant_params` for more details. + minimizer_kwargs (dict): the kwargs for scipy.optimize.minimize procedure. + test_fn (callable): a function to test the current performance of the model. + """ + if fold_sequences: + model = fold_batch_norms(model, dummy_input) + if args is None: + args = get_default_args() + elif isinstance(args, dict): + updated_args = get_default_args() + updated_args.__dict__.update(args) + args = updated_args + original_model = deepcopy(model) + + if not act_stats and not args.qe_config_file: + msglogger.info('Collecting stats for model...') + model_temp = distiller.utils.make_non_parallel_copy(model) + act_stats = collect_quant_stats(model_temp, eval_fn) + del model_temp + if args: + act_stats_path = '%s_act_stats.yaml' % args.arch + msglogger.info('Done. 
Saving act stats into %s' % act_stats_path) + distiller.yaml_ordered_save(act_stats_path, act_stats) + args.qe_stats_file = act_stats_path + + # Preparing model and init conditions: + msglogger.info("Initializing quantizer...") + quantizer = PostTrainLinearQuantizer.from_args(model, args) + + # Make sure weights are re-quantizable and clip-able + quantizer.save_fp_weights = True + quantizer.also_clip_weights = True + + # Disable any user set activations clipping - we'll be using init_args + quantizer.clip_acts = ClipMode.NONE + for overrides_dict in quantizer.module_overrides_map.values(): + overrides_dict.pop('clip_acts', None) + + quantizer.prepare_model(dummy_input) + quantizer.model.eval() + + validate_quantization_settings(args, quantizer.model) + + msglogger.info("Initializing quantization parameters...") + init_args = init_args or (args.init_mode, args.init_mode_method) + init_linear_quant_params(quantizer, original_model, eval_fn, dummy_input, *init_args, + search_clipping=args.search_clipping) + + msglogger.info("Evaluating initial quantization score...") + best_data = { + 'score': eval_fn(model), + 'qp_dict': deepcopy(quantizer.linear_quant_params) + } + msglogger.info("Evaluation set loss after initialization %.3f" % best_data['score']) + if test_fn: + msglogger.info('Evaluating on full test set...') + l_top1, l_top5, l_loss = test_fn(quantizer.model) + msglogger.info('Test: \tloss=%.3f, top1=%.3f, top5=%.3f ' % (l_loss, l_top1, l_top5)) + + init_qp_dict = OrderedDict(quantizer.named_linear_quant_params(args.search_clipping, filter=True)) + keys, init_qp_vec = quant_params_dict2vec(init_qp_dict, args.search_clipping) + + iter_counter = count(1) + eval_counter = count(1) + + def feed_forward_fn(qp_vec): + # if not _check_qp_vec(keys, qp_vec, quant_mode, args.search_clipping): + # return 1e6 + qp_dict = quant_params_vec2dict(keys, qp_vec, args.search_clipping) + quantizer.update_linear_quant_params(qp_dict) + loss = eval_fn(quantizer.model) + + i = next(eval_counter) + if i % 20 == 0: + msglogger.info('%d evaluations: loss=%.3f' % (i, loss)) + + return loss + + def callback(qp_vec): + score = feed_forward_fn(qp_vec) + i = next(iter_counter) + msglogger.info("Iteration %d: \t Score=%.3f" % (i, score)) + if score < best_data['score']: + best_data['score'] = score + best_data['qp_dict'] = quant_params_vec2dict(keys, qp_vec, args.search_clipping) + msglogger.info("Saving current best quantization parameters.") + if test_fn: + msglogger.info('Evaluating on full test set...') + l_top1, l_top5, l_loss = test_fn(quantizer.model) + msglogger.info('Test: \tloss=%.3f, top1=%.3f, top5=%.3f ' % (l_loss, l_top1, l_top5)) + + options = options or OrderedDict() + if args.maxiter is not None: + options['maxiter'] = args.maxiter + if args.maxfev is not None: + options['maxfev'] = args.maxfev + minimizer_kwargs = minimizer_kwargs or OrderedDict() + minimizer_kwargs.update({ + 'method': method, 'options': options + }) + basinhopping = basinhopping or args.basinhopping + if basinhopping: + msglogger.info('Using basinhopping global minimum search with "%s" local minimization method'% + method) + res = opt.basinhopping(feed_forward_fn, init_qp_vec, args.niter, callback=callback, + minimizer_kwargs=minimizer_kwargs) + else: + msglogger.info('Using "%s" minimization algorithm.' % method) + res = opt.minimize(feed_forward_fn, init_qp_vec, callback=callback, **minimizer_kwargs) + + msglogger.info("Optimization done. 
Best configuration: %s" % best_data['qp_dict']) + return model, best_data['qp_dict'] + + +if __name__ == "__main__": + args = get_default_args() + args.epochs = float('inf') # hack for args parsing so there's no error in epochs + cc = classifier.ClassifierCompressor(args, script_dir=os.path.dirname(__file__)) + + args = deepcopy(cc.args) + + effective_test_size_bak = args.effective_test_size + args.effective_test_size = args.opt_val_size + eval_data_loader = classifier.load_data(args, load_train=False, load_val=False, load_test=True, fixed_subset=True) + + args.effective_test_size = effective_test_size_bak + test_data_loader = classifier.load_data(args, load_train=False, load_val=False, load_test=True) + + # logging + logging.getLogger().setLevel(logging.WARNING) + msglogger = logging.getLogger(__name__) + msglogger.setLevel(logging.INFO) + + model = cc.model.eval() + device = next(model.parameters()).device + + if args.memoize_dataloader: + memoized_data_loader = [] + for images, targets in eval_data_loader: + batch = images.to(device), targets.to(device) + memoized_data_loader.append(batch) + else: + memoized_data_loader = None + + def eval_fn(model): + if args.memoize_dataloader: + loss = 0 + for images, targets in memoized_data_loader: + outputs = model(images) + loss += cc.criterion(outputs, targets).item() + loss = loss / len(memoized_data_loader) + else: + _, _, loss = classifier.test(eval_data_loader, model, cc.criterion, [cc.tflogger, cc.pylogger], + None, args) + return loss + + def test_fn(model): + return classifier.test(test_data_loader, model, cc.criterion, [cc.tflogger, cc.pylogger], None, args) + + args.device = device + if args.resumed_checkpoint_path: + args.load_model_path = args.resumed_checkpoint_path + if args.load_model_path: + msglogger.info("Loading checkpoint from %s" % args.load_model_path) + model = apputils.load_lean_checkpoint(model, args.load_model_path, + model_device=args.device) + + if args.qe_stats_file: + msglogger.info("Loading stats from %s" % args.qe_stats_file) + with open(args.qe_stats_file, 'r') as f: + act_stats = distiller.yaml_ordered_load(f) + else: + act_stats = None + + dummy_input = torch.rand(*model.input_shape, device=args.device) + model, qp_dict = ptq_coordinate_search(model, dummy_input, eval_fn, args.method, + args=args, act_stats=act_stats, test_fn=test_fn) + + top1, top5, loss = test_fn(model) + + msglogger.info("Arch: %s \tTest: \t top1 = %.3f \t top5 = %.3f \t loss = %.3f" % + (args.arch, top1, top5, loss)) + distiller.yaml_ordered_save('%s.quant_params_dict.yaml' % args.arch, qp_dict) + + distiller.apputils.save_checkpoint(0, args.arch, model, extras={'top1': top1, 'qp_dict': qp_dict}, name=args.name, + dir=cc.logdir) diff --git a/distiller/quantization/quantizer.py b/distiller/quantization/quantizer.py index 6f4943c..1b33530 100644 --- a/distiller/quantization/quantizer.py +++ b/distiller/quantization/quantizer.py @@ -23,6 +23,7 @@ import torch.nn as nn import distiller import warnings from typing import Callable, Optional +from copy import deepcopy msglogger = logging.getLogger() @@ -187,6 +188,7 @@ class Quantizer(object): # A dictionary of replaced modules and their respective names. 
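
The small quantizer.py addition recorded in this hunk (modules_processed_args) is what the per-layer initialization above relies on: it stores the (args, kwargs) each wrapper was built with, so the same replacement can be re-run later with tweaked settings. A sketch of that mechanism, mirroring init_layer_linear_quant_params; the helper name is hypothetical and default_repalcement_fn keeps the attribute's existing spelling in Distiller:

from copy import deepcopy

def recreate_wrapper(quantizer, original_layer, full_name):
    # (args, kwargs) captured for each replaced module during prepare_model()
    layer_args, layer_kwargs = quantizer.modules_processed_args[full_name]
    replace_fn = quantizer.replacement_factory.get(type(original_layer),
                                                   quantizer.default_repalcement_fn)
    # Re-run the same replacement on a fresh copy of the FP32 layer
    return replace_fn(deepcopy(original_layer), *layer_args, **layer_kwargs).eval()

In the search code this is done with a modified 'clip_acts' entry in the stored kwargs, which is how different initialization clip modes are tried per layer.
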
self.modules_processed = OrderedDict() + self.modules_processed_args = OrderedDict() def _add_qbits_entry(self, module_name, module_type, qbits): if module_type not in [nn.Conv2d, nn.Conv3d, nn.Linear, nn.Embedding]: @@ -314,6 +316,9 @@ class Quantizer(object): replace_msg(full_name, (module, new_module)) # Add to history of prepared submodules self.modules_processed[module] = full_name, new_module + # To allow recreating this wrapper later on + valid_args = full_name, deepcopy(self.module_qbits_map) + self.modules_processed_args[full_name] = valid_args, valid_kwargs setattr(container, name, new_module) # If a "leaf" module was replaced by a container, add the new layers to the QBits mapping diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py index 287776b..5e3c7f8 100644 --- a/distiller/quantization/range_linear.py +++ b/distiller/quantization/range_linear.py @@ -16,6 +16,7 @@ import torch.nn as nn import torch.nn.functional as f +import numpy as np import argparse from enum import Enum from collections import OrderedDict, namedtuple @@ -124,18 +125,22 @@ def _get_saturation_fn(quant_mode, clip_mode, num_stds, num_bits=None): def _get_quant_params_from_tensor(tensor, num_bits, mode, clip=ClipMode.NONE, per_channel=False, num_stds=None, half_range=False, scale_approx_mult_bits=None): if per_channel and tensor.dim() not in [2, 4]: - raise ValueError('Per channel quantization possible only with 2D or 4D tensors (linear or conv layer weights)') + raise UnsatisfiedRequirements('Per channel quantization possible only with ' + '2D or 4D tensors (linear or conv layer weights)') if clip == ClipMode.N_STD: if per_channel: raise ValueError('N_STD clipping not supported with per-channel quantization') if num_stds is None: - raise ValueError('Clip mode set top N_STD but \'num_stds\' parameter not provided') + raise UnsatisfiedRequirements('Clip mode set top N_STD but \'num_stds\' parameter not provided') dim = 0 if clip == ClipMode.AVG or per_channel else None sat_fn = _get_saturation_fn(mode, clip, num_stds, num_bits) if is_linear_quant_mode_symmetric(mode): sat_val = sat_fn(tensor, dim) + if isinstance(sat_val, tuple): + assert len(sat_val) == 2 + sat_val = torch.max(*sat_val) scale, zp = symmetric_linear_quantization_params(num_bits, sat_val, restrict_qrange=mode == LinearQuantMode.SYMMETRIC_RESTRICTED) else: # Asymmetric mode @@ -193,6 +198,25 @@ def _get_quant_params_from_stats_dict(stats, num_bits, mode, clip=ClipMode.NONE, return scale, zp +def _get_clipping_values(scale, zp, num_bits, mode): + """ + Gets the saturation values induced by quantization values + Args: + scale, zp (torch.Tensor or float): quantization params + num_bits (int): number of bits + mode (LinearQuantMode): mode of quantization + Returns: + min, max : tuple[float, float] + """ + device = scale.device if isinstance(scale, torch.Tensor) else 'cpu' + if is_linear_quant_mode_asymmetric(mode): + t = torch.tensor([0, 2**num_bits-1], device=device) + else: + t = torch.tensor([-2**(num_bits-1), 2**(num_bits-1)-1], device=device) + sat_min, sat_max = linear_dequantize(t, scale, zp) # type: torch.Tensor + return sat_min, sat_max + + ############################################################################### # Post Training ############################################################################### @@ -218,8 +242,7 @@ class QuantSettings(object): return '(num_bits={} ; quant_mode={} ; clip_mode={} ; clip_n_stds={} ; clip_half_range={}' \ ' ; per_channel={})'.format(self.num_bits, 
_enum_to_str(self.quant_mode), _enum_to_str(self.clip_mode), self.clip_n_stds, self.clip_half_range, - self.per_channel - ) + self.per_channel) def linear_quantize_clamp_with_metadata(t, inplace=False): @@ -294,11 +317,15 @@ def add_post_train_quant_args(argparser): group.add_argument('--qe-no-clip-layers', '--qencl', type=str, nargs='+', metavar='LAYER_NAME', default=[], help='List of layer names for which not to clip activations. Applicable ' 'only if --qe-clip-acts is not \'none\'') + group.add_argument('--qe-no-quant-layers', '--qenql', type=str, nargs='+', metavar='LAYER_NAME', default=[], + help='List of layer names for which to skip quantization.') group.add_argument('--qe-per-channel', '--qepc', action='store_true', help='Enable per-channel quantization of weights (per output channel)') group.add_argument('--qe-scale-approx-bits', '--qesab', type=int, metavar='NUM_BITS', help='Enables scale factor approximation using integer multiply + bit shift, using ' 'this number of bits the integer multiplier') + group.add_argument('--qe-save-fp-weights', action='store_true', + help='Allow weights requantization.') stats_group = group.add_mutually_exclusive_group() stats_group.add_argument('--qe-stats-file', type=str, metavar='PATH', @@ -313,6 +340,25 @@ def add_post_train_quant_args(argparser): 'all other --qe* arguments are ignored)') +class UnsatisfiedRequirements(Exception): + pass + + +def _check_clipping_val(val, quant_mode, half_range): + if isinstance(val, float): + if is_linear_quant_mode_symmetric(quant_mode): + return -val, val + elif half_range: + return 0, val + raise ValueError('For asymmetric quantization, setting clipping values only allowed ' + 'using both min/max values.') + if isinstance(val, (tuple, list, np.ndarray, torch.Tensor)): + assert all(distiller.is_scalar(v) for v in val), 'Elements of the clipping value must be scalar-like.' + assert len(val) == 2, 'Clipping value must have 2 elements.' 
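
The helper being defined here normalizes user-supplied clipping values into a (min, max) pair. A few illustrative cases, assuming the symmetric/half-range behavior shown in the code; the numeric values are examples only:

from distiller.quantization.range_linear import _check_clipping_val, LinearQuantMode

# Symmetric mode: a single scalar is mirrored around zero
assert _check_clipping_val(3.0, LinearQuantMode.SYMMETRIC, half_range=False) == (-3.0, 3.0)

# Fused-ReLU (half-range) activations: a single scalar means [0, val]
assert _check_clipping_val(3.0, LinearQuantMode.ASYMMETRIC_UNSIGNED, half_range=True) == (0, 3.0)

# Other asymmetric cases require an explicit (min, max) pair
assert _check_clipping_val((-1.0, 4.0), LinearQuantMode.ASYMMETRIC_SIGNED, half_range=False) == (-1.0, 4.0)
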
+ return tuple(val) + raise TypeError('Clipping value should be a scalar or an iterable of these') + + class RangeLinearQuantWrapper(nn.Module): """ Base class for module which wraps an existing module with linear range-base quantization functionality @@ -356,6 +402,11 @@ class RangeLinearQuantWrapper(nn.Module): self.preset_act_stats = False self.register_buffer('num_forwards', torch.zeros(1, dtype=torch.long)) + self.register_buffer('force_readjust', torch.tensor(False)) + + # The accumulator is always signed + self.accum_min_q_val, self.accum_max_q_val = get_quantized_range(num_bits_accum, signed=True, + signed_restrict_qrange=False) # Activations not quantized - stop here if num_bits_acts is None: @@ -391,9 +442,6 @@ class RangeLinearQuantWrapper(nn.Module): restrict_qrange = mode.activations == LinearQuantMode.SYMMETRIC_RESTRICTED self.acts_min_q_val, self.acts_max_q_val = get_quantized_range(num_bits_acts, signed=signed, signed_restrict_qrange=restrict_qrange) - # The accumulator is always signed - self.accum_min_q_val, self.accum_max_q_val = get_quantized_range(num_bits_accum, signed=True, - signed_restrict_qrange=False) if activation_stats: self.preset_act_stats = True @@ -417,16 +465,73 @@ class RangeLinearQuantWrapper(nn.Module): scale, zp = _get_quant_params_from_stats_dict(activation_stats['output'], num_bits_acts, mode.activations, clip_acts, clip_n_stds, clip_half_range, scale_approx_mult_bits) + if not isinstance(scale, torch.Tensor): + scale, zp = torch.tensor(scale), torch.tensor(zp) self.register_buffer('output_scale', scale) self.register_buffer('output_zero_point', zp) else: self.preset_act_stats = False - def named_acts_quant_params(self): + def named_linear_quant_params(self, filter=False): if self.output_quant_settings.num_bits is not None and self.preset_act_stats: # Output scale buffers are saved in the model only when stats are used yield 'output_scale', self.output_scale - yield 'output_zero_point', self.output_zero_point + if not filter or (is_linear_quant_mode_asymmetric(self.mode.activations) and not self.clip_half_range): + yield 'output_zero_point', self.output_zero_point + + def set_linear_quant_param(self, name, val): + if name in dict(self.named_clipping()): + setattr(self, name, val) + elif name not in dict(self.named_linear_quant_params()): + raise ValueError('%s is not a quantization parameter.' 
% name) + else: + getattr(self, name).data.fill_(val) + self.force_readjust.fill_(True) + + def _check_requirements_output_clipping(self): + if not self.output_quant_settings.num_bits: + raise UnsatisfiedRequirements('Cannot retrieve clipping values because ' + 'the activations aren\'t quantized.') + if not self.preset_act_stats: + raise UnsatisfiedRequirements('Cannot retrieve clipping values ' + 'because the activations stats were not provided.') + + @property + def output_clipping(self): + self._check_requirements_output_clipping() + bits = self.output_quant_settings.num_bits + scale, zp = self.output_scale, self.output_zero_point + return _get_clipping_values(scale, zp, bits, self.output_quant_settings.quant_mode) + + @output_clipping.setter + def output_clipping(self, val): + """ + Args: + val (float or tuple[float, float] or tuple[torch.Tensor, torch.Tensor]): the value to set + """ + self._check_requirements_output_clipping() + qset = self.output_quant_settings + val_min, val_max = _check_clipping_val(val, qset.quant_mode, self.clip_half_range) + qset.clip_mode, qset.clip_half_range, qset.clip_n_stds = ClipMode.NONE, None, None + scale, zp = _get_quant_params_from_stats_dict({'min': val_min, 'max': val_max}, qset.num_bits, qset.quant_mode, + scale_approx_mult_bits=self.scale_approx_mult_bits) + self.set_linear_quant_param('output_scale', scale.item()) + self.set_linear_quant_param('output_zero_point', zp.item()) + + def named_clipping(self, filter=False): + val = self.output_clipping + if filter and (is_linear_quant_mode_symmetric(self.mode.activations) or self.clip_half_range): + val = val[1] + yield 'output_clipping', val + + def update_linear_quant_params(self, new_config): + """ + Updates all the quant params using a dictionary. + Args: + new_config (dict): the new configuration dict. 
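
With these accessors in place, a wrapper's quantization can be adjusted after construction in two ways: directly through scale/zero-point, or through clipping values that are converted back into scale/zero-point. A sketch, assuming a RangeLinearQuantWrapper whose activations are quantized and which was created with activation stats; the numbers are illustrative:

def retune_wrapper(wrapper):
    # Direct route: overwrite the searched parameters, exactly what the
    # coordinate-search loop does through update_linear_quant_params()
    wrapper.update_linear_quant_params({'output_scale': 0.04, 'output_zero_point': 0.0})

    # Clipping route: setting output_clipping recomputes scale/zero_point from the
    # given saturation values (a single float is enough for symmetric/half-range modes)
    wrapper.output_clipping = (-2.0, 2.0)
    print(dict(wrapper.named_clipping()))
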
+ """ + for name, val in new_config.items(): + self.set_linear_quant_param(name, val) def forward(self, *inputs): if self.training: @@ -617,8 +722,8 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper): """ def __init__(self, wrapped_module, num_bits_acts, num_bits_params, num_bits_accum=32, mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE, per_channel_wts=False, activation_stats=None, - clip_n_stds=None, clip_half_range=False, scale_approx_mult_bits=None, - input_overrides=None, inputs_quant_auto_fallback=False): + clip_n_stds=None, clip_half_range=False, scale_approx_mult_bits=None, input_overrides=None, + inputs_quant_auto_fallback=False, save_fp_weights=False, also_clip_weights=False): super(RangeLinearQuantParamLayerWrapper, self).__init__(wrapped_module, num_bits_acts, num_bits_accum, mode, clip_acts, activation_stats, clip_n_stds, clip_half_range, scale_approx_mult_bits, @@ -631,20 +736,32 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper): # If activations are not quantized, we do fake quantization of the parameters, that is - quant and de-quant self.fake_quant_params = self.output_quant_settings.num_bits is None - self.wts_quant_settings = QuantSettings(num_bits_params, self.mode.weights, ClipMode.NONE, None, False, - per_channel_wts) + clip_wts_mode, clip_wts_n_stds = ClipMode.NONE, None + if also_clip_weights: + clip_wts_mode = self.output_quant_settings.clip_mode + clip_wts_n_stds = self.output_quant_settings.clip_n_stds + self.wts_quant_settings = QuantSettings(num_bits_params, self.mode.weights, clip_wts_mode, clip_wts_n_stds, + False, per_channel_wts) self.params_min_q_val, self.params_max_q_val = get_quantized_range( self.wts_quant_settings.num_bits, self.wts_quant_settings.quant_mode != LinearQuantMode.ASYMMETRIC_UNSIGNED, self.wts_quant_settings.quant_mode == LinearQuantMode.SYMMETRIC_RESTRICTED ) + self.save_fp_weights = save_fp_weights + # save the float weight to allow re-quantizing + if save_fp_weights: + wrapped_module.register_buffer('float_weight', wrapped_module.weight.clone().detach()) # Quantize weights - overwrite FP32 weights w_scale, w_zero_point = _get_quant_params_from_tensor(wrapped_module.weight, self.wts_quant_settings.num_bits, self.wts_quant_settings.quant_mode, - per_channel=self.wts_quant_settings.per_channel) + clip=self.wts_quant_settings.clip_mode, + per_channel=self.wts_quant_settings.per_channel, + num_stds=self.wts_quant_settings.clip_n_stds) + w_scale = w_scale if isinstance(w_scale, torch.Tensor) else torch.tensor(w_scale) + w_zero_point = w_zero_point if isinstance(w_zero_point, torch.Tensor) else torch.tensor(w_zero_point) self.register_buffer('w_scale', w_scale) self.register_buffer('w_zero_point', w_zero_point) @@ -655,17 +772,21 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper): self.has_bias = hasattr(wrapped_module, 'bias') and wrapped_module.bias is not None if self.has_bias and (self.fake_quant_params or not self.preset_act_stats): b_scale, b_zero_point = _get_quant_params_from_tensor(wrapped_module.bias, - self.wts_quant_settings.num_bits, - self.wts_quant_settings.quant_mode) + self.accum_quant_settings.num_bits, + self.accum_quant_settings.quant_mode) self.register_buffer('b_scale', b_scale) self.register_buffer('b_zero_point', b_zero_point) base_b_q = linear_quantize_clamp(wrapped_module.bias.data, self.b_scale, self.b_zero_point, - self.params_min_q_val, self.params_max_q_val) + self.accum_min_q_val, self.accum_max_q_val) if not self.preset_act_stats: # Dynamic ranges - 
save in auxiliary buffer, # requantize each time based on dynamic input scale factor self.register_buffer('base_b_q', base_b_q) + # allow requantizing the bias: + if self.has_bias and self.preset_act_stats: + self.register_buffer('fp_bias', self.wrapped_module.bias.data.clone().detach()) + # Activations not quantized - de-quant parameters and return if self.fake_quant_params: linear_dequantize(wrapped_module.weight.data, self.w_scale, self.w_zero_point, inplace=True) @@ -676,7 +797,7 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper): # Activations are quantized - setup accumulator quantization parameters device = self.w_scale.device if self.preset_act_stats: - t = torch.zeros_like(self.w_scale) + t = torch.empty_like(self.w_scale) if self.wts_quant_settings.per_channel: t = t.squeeze(dim=-1) self.register_buffer('accum_scale', t) @@ -688,7 +809,70 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper): # and all subsequent calls are done with these shifted weights. # Upon calling `self.state_dict()` - we restore the actual quantized weights. # i.e. is_simulated_quant_weight_shifted = False - self.register_buffer('is_simulated_quant_weight_shifted', torch.tensor(0, dtype=torch.uint8, device=device)) + self.register_buffer('is_simulated_quant_weight_shifted', torch.tensor(False, device=device)) + + def named_linear_quant_params(self, filter=False): + if self.save_fp_weights: + yield 'w_scale', self.w_scale + if not filter or is_linear_quant_mode_asymmetric(self.mode.weights): + yield 'w_zero_point', self.w_zero_point + yield from super(RangeLinearQuantParamLayerWrapper, self).named_linear_quant_params(filter=filter) + + def set_linear_quant_param(self, name, val): + if name in ['w_scale', 'w_zero_point']: + if self.save_fp_weights: + super().set_linear_quant_param(name, val) + self.wrapped_module.weight.data.copy_(self.wrapped_module.float_weight.data) + linear_quantize_clamp(self.wrapped_module.weight.data, self.w_scale, self.w_zero_point, + self.params_min_q_val, + self.params_max_q_val, inplace=True) + if self.fake_quant_params: + linear_dequantize(self.wrapped_module.weight.data, self.w_scale, self.w_zero_point, inplace=True) + + else: + raise UnsatisfiedRequirements('Cannot re-quantize the weights. Please specify \'save_fp_weights\' in ' + 'the %s constructor to enable re-quantizing the weights.' % + self.__class__.__name__) + else: + super().set_linear_quant_param(name, val) + + def _check_requirements_weights_clipping(self, setter=False): + if not self.wts_quant_settings.num_bits: + raise UnsatisfiedRequirements('Cannot retrieve clipping values because the weights aren\'t quantized.') + if setter and not self.save_fp_weights: + warnings.warn('Without saving fp32 version of weights, re-quantization is disabled. 
To enable, ' + 'please set \'save_fp_weights\' while constructing the wrapper.') + + @property + def weight_clipping(self): + self._check_requirements_weights_clipping(setter=False) + bits, mode = self.wts_quant_settings.num_bits, self.wts_quant_settings.quant_mode + scale, zp = self.w_scale, self.w_zero_point + return _get_clipping_values(scale, zp, bits, mode) + + @weight_clipping.setter + def weight_clipping(self, val): + self._check_requirements_weights_clipping(setter=True) + bits = self.wts_quant_settings.num_bits + val_min, val_max = _check_clipping_val(val, self.wts_quant_settings.quant_mode, False) + if is_linear_quant_mode_symmetric(self.wts_quant_settings.quant_mode): + # in symmetric quantization - we only need one value + scale, zp = symmetric_linear_quantization_params(bits, abs(max(val_min, val_max))) + else: + signed = self.wts_quant_settings.quant_mode == LinearQuantMode.ASYMMETRIC_SIGNED + scale, zp = asymmetric_linear_quantization_params(bits, val_min, val_max, signed=signed) + self.set_linear_quant_param('w_scale', scale) + self.set_linear_quant_param('w_zero_point', zp) + + def named_clipping(self, filter=False): + try: + yield from super().named_clipping(filter=filter) + except UnsatisfiedRequirements as ex: + warnings.warn(str(ex)) # probably the output isn't quantized + val = self.weight_clipping + if filter and is_linear_quant_mode_symmetric(self.mode.weights): + val = val[1] + yield 'weight_clipping', val def state_dict(self, destination=None, prefix='', keep_vars=False): if not self.fake_quant_params and self.is_simulated_quant_weight_shifted: @@ -709,12 +893,15 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper): return accum_scale if self.preset_act_stats: - if self.num_forwards == 0: - self.accum_scale += get_accum_scale(input_q) + if self.num_forwards == 0 or self.force_readjust: + self.accum_scale.copy_(get_accum_scale(input_q)) if self.has_bias: # Requantize bias to accumulator scale "permanently" - linear_quantize_clamp(self.wrapped_module.bias.data, self.accum_scale.squeeze(), 0, - self.accum_min_q_val, self.accum_max_q_val, inplace=True) + self.wrapped_module.bias.data.copy_( + linear_quantize_clamp(self.fp_bias, self.accum_scale.squeeze(), 0, + self.accum_min_q_val, self.accum_max_q_val) + ) + self.force_readjust.fill_(False) else: self.accum_scale = get_accum_scale(input_q) if self.has_bias: @@ -997,7 +1184,7 @@ class FP16Wrapper(FPWrapper): class RangeLinearEmbeddingWrapper(nn.Module): - def __init__(self, wrapped_module, num_bits, mode=LinearQuantMode.SYMMETRIC, stats=None): + def __init__(self, wrapped_module, num_bits, mode=LinearQuantMode.SYMMETRIC, stats=None, save_fp_weights=False): if not isinstance(wrapped_module, nn.Embedding): raise ValueError(self.__class__.__name__ + ' can only wrap torch.nn.Embedding modules') @@ -1006,9 +1193,16 @@ class RangeLinearEmbeddingWrapper(nn.Module): mode = verify_quant_mode(mode) self.mode = mode - self.min_q_val, self.max_q_val = get_quantized_range( - num_bits, signed=mode.weights != LinearQuantMode.ASYMMETRIC_UNSIGNED, - signed_restrict_qrange=mode.weights == LinearQuantMode.SYMMETRIC_RESTRICTED) + self.wts_quant_settings = QuantSettings(num_bits, self.mode.weights, ClipMode.NONE, None, False, False) + + self.params_min_q_val, self.params_max_q_val = get_quantized_range( + self.wts_quant_settings.num_bits, + self.wts_quant_settings.quant_mode != LinearQuantMode.ASYMMETRIC_UNSIGNED, + self.wts_quant_settings.quant_mode == LinearQuantMode.SYMMETRIC_RESTRICTED + ) + self.save_fp_weights = 
save_fp_weights + if save_fp_weights: + wrapped_module.register_buffer('float_weight', wrapped_module.weight.clone().detach()) if stats is None: w_scale, w_zero_point = _get_quant_params_from_tensor(wrapped_module.weight, num_bits, mode.weights) @@ -1016,17 +1210,76 @@ class RangeLinearEmbeddingWrapper(nn.Module): w_scale, w_zero_point = _get_quant_params_from_stats_dict(stats['output'], num_bits, mode.weights) device = wrapped_module.weight.device - self.register_buffer('w_scale', w_scale.to(device)) self.register_buffer('w_zero_point', w_zero_point.to(device)) - linear_quantize_clamp(wrapped_module.weight.data, self.w_scale, self.w_zero_point, self.min_q_val, - self.max_q_val, inplace=True) - self.quant_metadata = TensorQuantMetadata(self.w_scale, self.w_zero_point, self.min_q_val, self.max_q_val) + linear_quantize_clamp(wrapped_module.weight.data, self.w_scale, self.w_zero_point, + self.params_min_q_val, self.params_max_q_val, inplace=True) + self.quant_metadata = TensorQuantMetadata(self.w_scale, self.w_zero_point, + self.params_min_q_val, self.params_max_q_val) self.wrapped_module = wrapped_module - def named_acts_quant_params(self): + def named_linear_quant_params(self, filter=False): yield 'w_scale', self.w_scale - yield 'w_zero_point', self.w_zero_point + if not filter or is_linear_quant_mode_asymmetric(self.mode.weights): + yield 'w_zero_point', self.w_zero_point + + def set_linear_quant_param(self, name, val): + if name in ['w_scale', 'w_zero_point']: + if self.save_fp_weights: + getattr(self, name).fill_(val) + self.wrapped_module.weight.data.copy_(self.wrapped_module.float_weight.data) + linear_quantize_clamp(self.wrapped_module.weight.data, self.w_scale, self.w_zero_point, + self.params_min_q_val, + self.params_max_q_val, inplace=True) + else: + raise UnsatisfiedRequirements('Cannot re-quantize the weights. Please specify \'save_fp_weights\' in ' + 'the %s constructor to enable re-quantizing the weights.' % + self.__class__.__name__) + else: + raise KeyError('No quantization parameter called \'%s\'.' % name) + + def update_linear_quant_params(self, new_config): + """ + Updates all the quant params using a dictionary. + Args: + new_config (dict): the new configuration dict. + """ + for name, val in new_config.items(): + self.set_linear_quant_param(name, val) + + def _check_requirements_weights_clipping(self, setter=False): + if not self.wts_quant_settings.num_bits: + raise UnsatisfiedRequirements('Cannot retrieve clipping values because the weights aren\'t quantized.') + if setter and not self.save_fp_weights: + warnings.warn('Without saving fp32 version of weights, re-quantization is disabled. 
To enable, ' + 'please set \'save_fp_weights\' while constructing the wrapper.') + + @property + def weight_clipping(self): + self._check_requirements_weights_clipping(setter=False) + bits, mode = self.wts_quant_settings.num_bits, self.wts_quant_settings.quant_mode + scale, zp = self.w_scale, self.w_zero_point + return _get_clipping_values(scale, zp, bits, mode) + + @weight_clipping.setter + def weight_clipping(self, val): + self._check_requirements_weights_clipping(setter=True) + bits = self.wts_quant_settings.num_bits + val_min, val_max = _check_clipping_val(val, self.wts_quant_settings.quant_mode, False) + if is_linear_quant_mode_symmetric(self.wts_quant_settings.quant_mode): + # in symmetric quantization - we only need one value + scale, zp = symmetric_linear_quantization_params(bits, val_max) + else: + signed = self.wts_quant_settings.quant_mode == LinearQuantMode.ASYMMETRIC_SIGNED + scale, zp = asymmetric_linear_quantization_params(bits, val_min, val_max, signed=signed) + self.set_linear_quant_param('w_scale', scale) + self.set_linear_quant_param('w_zero_point', zp) + + def named_clipping(self, filter=False): + val = self.weight_clipping + if filter and is_linear_quant_mode_symmetric(self.mode.weights): + val = val[1] + yield 'weight_clipping', val def forward(self, input): out_q = self.wrapped_module(input) @@ -1038,7 +1291,14 @@ class RangeLinearEmbeddingWrapper(nn.Module): class RangeLinearFakeQuantWrapper(RangeLinearQuantWrapper): def __init__(self, wrapped_module, num_bits_acts, mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE, activation_stats=None, clip_n_stds=None, clip_half_range=False, scale_approx_mult_bits=None, - fpq_module=None, input_overrides=None, inputs_quant_auto_fallback=False): + fpq_module=None, input_overrides=None, inputs_quant_auto_fallback=False, quantize_inputs=False): + if isinstance(wrapped_module, (nn.ReLU, nn.ReLU6)): + # In case of ReLU + Gauss/Laplace clipping, need to clip according to stats before ReLU is applied + clip_half_range = True + if clip_acts in (ClipMode.GAUSS, ClipMode.LAPLACE): + activation_stats['output']['mean'] = activation_stats['inputs'][0]['mean'] + activation_stats['output']['std'] = activation_stats['inputs'][0]['std'] + activation_stats['output']['b'] = activation_stats['inputs'][0]['b'] super(RangeLinearFakeQuantWrapper, self).__init__(wrapped_module, num_bits_acts, mode=mode, clip_acts=clip_acts, activation_stats=activation_stats, clip_n_stds=clip_n_stds, clip_half_range=clip_half_range, @@ -1047,11 +1307,15 @@ class RangeLinearFakeQuantWrapper(RangeLinearQuantWrapper): inputs_quant_auto_fallback=inputs_quant_auto_fallback) self.fpq_module = str(fpq_module) if fpq_module else None self.dtype = torch.float + self.quantize_inputs = quantize_inputs if self.fpq_module: self.dtype = {'16': torch.half, '32': torch.float, '64': torch.double}[self.fpq_module] self.wrapped_module.to(self.dtype) def _prepare_input(self, idx, input): + if not self.quantize_inputs: + return input + previously_quantized = hasattr(input, 'quant_metadata') input.quant_metadata = self._get_input_quant_metadata(idx, input) if previously_quantized: @@ -1143,6 +1407,9 @@ class PostTrainLinearQuantizer(Quantizer): For details what this does and how to override it. fpq_module (Union[int, str]): use the modules in floating point mode and only quantize their outputs. takes the values (16, 32, 64) only, this will use RangeLinearFakeQuantWrapper. + save_fp_weights (bool): Indicates whether or not to save a copy of the floating point weights. 
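
A short sketch of how the two new constructor knobs might be used when building the quantizer manually rather than through from_args(); the 8-bit symmetric settings and the stats path are placeholders, not defaults mandated by the patch:

from distiller.quantization import PostTrainLinearQuantizer
from distiller.quantization.range_linear import LinearQuantMode

def build_requantizable_quantizer(model, dummy_input, stats_path):
    # save_fp_weights keeps an FP32 copy of each weight so w_scale/w_zero_point can be
    # changed later; also_clip_weights applies the activation clip mode to weights too.
    quantizer = PostTrainLinearQuantizer(model,
                                         bits_activations=8, bits_parameters=8,
                                         mode=LinearQuantMode.SYMMETRIC,
                                         model_activation_stats=stats_path,
                                         save_fp_weights=True,
                                         also_clip_weights=True)
    quantizer.prepare_model(dummy_input)
    return quantizer
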
+ This allows re-quantization of weight after the initial quantization. + Defaults to False for performance. Note: If fpq_module is set, all the layers (except those overridden in `overrides`) will be converted to the set floating point precision, regardless of bits_activations/parameters/accum. @@ -1152,7 +1419,7 @@ class PostTrainLinearQuantizer(Quantizer): per_channel_wts=False, model_activation_stats=None, fp16=False, clip_n_stds=None, clip_half_range=False, scale_approx_mult_bits=None, inputs_quant_auto_fallback=True, - fpq_module=None): + fpq_module=None, save_fp_weights=False, also_clip_weights=False): overrides_bkp = deepcopy(overrides) super(PostTrainLinearQuantizer, self).__init__(model, bits_activations=bits_activations, bits_weights=bits_parameters, bits_bias=bits_accum, @@ -1208,7 +1475,7 @@ class PostTrainLinearQuantizer(Quantizer): def replace_param_layer(module, name, qbits_map, per_channel_wts=per_channel_wts, mode=mode, fp16=fp16, scale_approx_mult_bits=scale_approx_mult_bits, - clip_acts=clip_acts, clip_n_stds=clip_n_stds, clip_half_range=clip_half_range, + clip_acts=None, clip_n_stds=clip_n_stds, clip_half_range=clip_half_range, input_overrides=None, fpq_module=fpq_module, fake=False): fpq_module = _check_fp16_arg(fp16, fpq_module) if fpq_module and not fake: @@ -1216,7 +1483,7 @@ class PostTrainLinearQuantizer(Quantizer): norm_name = distiller.utils.normalize_module_name(name) activation_stats = self.model_activation_stats.get(norm_name, None) - clip_acts = verify_clip_mode(clip_acts) + clip_acts = verify_clip_mode(clip_acts or self.clip_acts) qbits = qbits_map[name] if qbits.acts is not None and qbits.wts is None: # Quantizing only activations equals fake-quantization @@ -1228,7 +1495,8 @@ class PostTrainLinearQuantizer(Quantizer): clip_n_stds=clip_n_stds, clip_half_range=clip_half_range, scale_approx_mult_bits=scale_approx_mult_bits, fpq_module=fpq_module, input_overrides=input_overrides, - inputs_quant_auto_fallback=inputs_quant_auto_fallback) + inputs_quant_auto_fallback=inputs_quant_auto_fallback, + quantize_inputs=False) return RangeLinearQuantParamLayerWrapper(module, qbits.acts, qbits.wts, num_bits_accum=self.bits_accum, mode=mode, clip_acts=clip_acts, @@ -1237,11 +1505,13 @@ class PostTrainLinearQuantizer(Quantizer): clip_n_stds=clip_n_stds, clip_half_range=clip_half_range, scale_approx_mult_bits=scale_approx_mult_bits, input_overrides=input_overrides, - inputs_quant_auto_fallback=inputs_quant_auto_fallback) + inputs_quant_auto_fallback=inputs_quant_auto_fallback, + save_fp_weights=self.save_fp_weights, + also_clip_weights=self.also_clip_weights) def replace_non_param_layer(wrapper_type, module, name, qbits_map, fp16=fp16, scale_approx_mult_bits=scale_approx_mult_bits, - clip_acts=clip_acts, clip_n_stds=clip_n_stds, clip_half_range=clip_half_range, + clip_acts=None, clip_n_stds=clip_n_stds, clip_half_range=clip_half_range, input_overrides=None, inputs_quant_auto_fallback=inputs_quant_auto_fallback, fpq_module=fpq_module, fake=False): fpq_module = _check_fp16_arg(fp16, fpq_module) @@ -1250,7 +1520,7 @@ class PostTrainLinearQuantizer(Quantizer): norm_name = distiller.utils.normalize_module_name(name) activation_stats = self.model_activation_stats.get(norm_name, None) - clip_acts = verify_clip_mode(clip_acts) + clip_acts = verify_clip_mode(clip_acts or self.clip_acts) qbits = qbits_map[name] if fake: @@ -1259,7 +1529,8 @@ class PostTrainLinearQuantizer(Quantizer): clip_n_stds=clip_n_stds, clip_half_range=clip_half_range, 
scale_approx_mult_bits=scale_approx_mult_bits, fpq_module=fpq_module, input_overrides=input_overrides, - inputs_quant_auto_fallback=inputs_quant_auto_fallback) + inputs_quant_auto_fallback=inputs_quant_auto_fallback, + quantize_inputs=False) try: return wrapper_type(module, qbits.acts, mode=mode, clip_acts=clip_acts, activation_stats=activation_stats, @@ -1281,10 +1552,10 @@ class PostTrainLinearQuantizer(Quantizer): stats=self.model_activation_stats.get(norm_name, None)) def replace_fake_quant(module, name, qbits_map, fp16=fp16, - clip_acts=clip_acts, clip_n_stds=clip_n_stds, clip_half_range=clip_half_range, + clip_acts=None, clip_n_stds=clip_n_stds, clip_half_range=clip_half_range, scale_approx_mult_bits=scale_approx_mult_bits, input_overrides=None, inputs_quant_auto_fallback=inputs_quant_auto_fallback, - fpq_module=fpq_module, fake=True, make_identity=False): + fpq_module=fpq_module, fake=True, make_identity=False, quantize_inputs=True): if isinstance(module, (nn.ReLU, nn.ReLU6)) and make_identity: named_modules = OrderedDict(self.model.named_modules()) pred = self.adjacency_map[name].predecessors[0].name @@ -1299,19 +1570,22 @@ class PostTrainLinearQuantizer(Quantizer): return FPWrapper(module, fpq_module) norm_name = distiller.utils.normalize_module_name(name) - clip_acts = verify_clip_mode(clip_acts) + clip_acts = verify_clip_mode(clip_acts or self.clip_acts) return RangeLinearFakeQuantWrapper(module, qbits_map[name].acts, mode=mode, clip_acts=clip_acts, activation_stats=self.model_activation_stats.get(norm_name, None), clip_n_stds=clip_n_stds, clip_half_range=clip_half_range, scale_approx_mult_bits=scale_approx_mult_bits, fpq_module=fpq_module, input_overrides=input_overrides, - inputs_quant_auto_fallback=inputs_quant_auto_fallback) + inputs_quant_auto_fallback=inputs_quant_auto_fallback, + quantize_inputs=quantize_inputs) self.clip_acts = clip_acts self.clip_n_stds = clip_n_stds self.model_activation_stats = model_activation_stats or {} self.bits_accum = bits_accum self.mode = mode + self.save_fp_weights = save_fp_weights + self.also_clip_weights = also_clip_weights self.replacement_factory[nn.Conv2d] = replace_param_layer self.replacement_factory[nn.Conv3d] = replace_param_layer @@ -1341,34 +1615,90 @@ class PostTrainLinearQuantizer(Quantizer): self.default_repalcement_fn = replace_fake_quant self.replacement_blacklist.append(nn.Dropout) - save_dir = msglogger.logdir if hasattr(msglogger, 'logdir') else '.' - self.save_per_layer_parameters(save_dir) + # To be filled by .prepare_model() + self.linear_quant_params = None - def named_acts_quant_params(self): + def named_linear_quant_params(self, yield_clipping_params=False, filter=False): + if yield_clipping_params: + yield from self.named_clipping(filter=filter) + return for module_name, module in self.model.named_modules(): if is_post_train_quant_wrapper(module, include_fpwrapper=False): - for buff_name, buff in module.named_acts_quant_params(): + for buff_name, buff in module.named_linear_quant_params(filter=filter): full_buff_name = "%s.%s" % (module_name, buff_name) yield full_buff_name, buff - def set_act_quant_param(self, name, val): + def named_clipping(self, filter=False): + """ + Gets all the clipping parameters of the model. 
+        yields tuple[str, tuple[torch.Tensor, torch.Tensor]]
+        """
+        for module_name, module in self.model.named_modules():
+            if not is_post_train_quant_wrapper(module, include_fpwrapper=False):
+                continue
+            for clip_name, clip_val in module.named_clipping(filter=filter):  # type: str, tuple[torch.Tensor, torch.Tensor]
+                yield '%s.%s' % (module_name, clip_name), clip_val
+
+    def set_clipping(self, name, val):
+        """
+        Sets a clipping parameter by name.
+        Args:
+            name (str): the name of the clipping parameter.
+            val (tuple[float or torch.Tensor, float or torch.Tensor]): the value of the clipping.
+        """
+        module_name = distiller.param_name_2_module_name(name)
+        clip_name = name.split('.')[-1]
+        module = dict(self.model.named_modules())[module_name]
+        if not is_post_train_quant_wrapper(module, False):
+            raise ValueError('\'%s\' isn\'t a wrapper and has no clipping parameters.' % module_name)
+        if clip_name not in dict(module.named_clipping()):
+            raise ValueError('\'%s\' is not a clipping parameter.' % clip_name)
+        setattr(module, clip_name, val)
+
+    def update_clipping_parameters(self, clipping_config):
+        """
+        Updates all clipping parameters according to a configuration dict.
+        Args:
+            clipping_config (dict[str, tuple[float or torch.Tensor, float or torch.Tensor]]):
+                the clipping configuration.
+        """
+        for name, val in clipping_config.items():
+            self.set_clipping(name, val)
+
+    def _is_clipping_parameter(self, name):
+        module_name = distiller.param_name_2_module_name(name)
+        clip_name = name.split('.')[-1]
+        module = dict(self.model.named_modules())[module_name]
+        return is_post_train_quant_wrapper(module, False) and clip_name in dict(module.named_clipping())
+
+    def force_readjust_wrappers(self):
+        def _force_readjust(module):
+            if isinstance(module, RangeLinearQuantWrapper):
+                module.force_readjust.fill_(True)
+        self.model.apply(_force_readjust)
+
+    def set_linear_quant_param(self, name, val):
         """
         Sets the quant parameter by module_name.quant_param_name.
+        Can also set the clipping values.
         Args:
             name (str): the name of the quant param [module_name].[quant_param_name]
-            val (int or float or torch.Tensor): the new value.
+            val: the new value.
         """
-        self.acts_quant_params[name].fill_(val)
+        if self._is_clipping_parameter(name):
+            self.set_clipping(name, val)
+        else:
+            self.linear_quant_params[name].data.fill_(val)
+        self.force_readjust_wrappers()
 
-    def update_acts_quant_params(self, new_config):
+    def update_linear_quant_params(self, new_config):
         """
         Updates all the quant params using a dictionary.
         Args:
             new_config (dict): the new configuration dict.
""" for k, v in new_config.items(): - self.set_act_quant_param(k, v) - + self.set_linear_quant_param(k, v) @classmethod def from_args(cls, model, args): @@ -1385,9 +1715,15 @@ class PostTrainLinearQuantizer(Quantizer): if args.qe_bits_wts == 0: args.qe_bits_wts = None overrides = OrderedDict( - [(layer, OrderedDict([('clip_acts', 'NONE')])) - for layer in args.qe_no_clip_layers] + [ + (layer, OrderedDict([('bits_activations', None), ('bits_weights', None)])) + for layer in args.qe_no_quant_layers + ] ) + overrides.update(OrderedDict( + [(layer, OrderedDict([('clip_acts', 'NONE')])) + for layer in args.qe_no_clip_layers if layer not in args.qe_no_quant_layers] + )) mode_acts = args.qe_mode_acts or args.qe_mode mode_wts = args.qe_mode_wts or args.qe_mode mode = ModuleQuantMode(mode_acts, mode_wts) @@ -1402,7 +1738,8 @@ class PostTrainLinearQuantizer(Quantizer): clip_n_stds=args.qe_clip_n_stds, scale_approx_mult_bits=args.qe_scale_approx_bits, overrides=overrides, - inputs_quant_auto_fallback=True) + inputs_quant_auto_fallback=True, + save_fp_weights=args.qe_save_fp_weights) def save_per_layer_parameters(self, save_dir=''): defaults = OrderedDict(self.model.quantizer_metadata['params']) @@ -1410,8 +1747,10 @@ class PostTrainLinearQuantizer(Quantizer): defaults.pop('bits_parameters') defaults.pop('bits_accum') out = OrderedDict() - for n, m in self.model.named_modules(): - if distiller.has_children(m): + for n in self.module_overrides_map: + modules_dict = dict(self.model.named_modules()) + m = modules_dict[n] + if distiller.has_children(m) and not is_post_train_quant_wrapper(m, include_fpwrapper=False): continue qbits = self.module_qbits_map[n] d = OrderedDict() @@ -1422,6 +1761,11 @@ class PostTrainLinearQuantizer(Quantizer): actual_v = self.module_overrides_map[n].get(k, v) d[k] = actual_v out[n] = d + if self.linear_quant_params: + out['linear_quant_params'] = lqp_dict = OrderedDict() + for k, v in self.linear_quant_params.items(): # type: str, torch.Tensor + lqp_dict[k] = v.item() + save_path = os.path.join(save_dir, 'layer_quant_params.yaml') distiller.yaml_ordered_save(save_path, out) msglogger.info('Per-layer quantization parameters saved to ' + save_path) @@ -1438,10 +1782,12 @@ class PostTrainLinearQuantizer(Quantizer): # Setting dummy_input to None to make sure SummaryGraph won't be called dummy_input = None elif dummy_input is None: - raise ValueError('PostTrainLinearQuantizer requires dummy input in order to perform certain optimizations') + raise UnsatisfiedRequirements('PostTrainLinearQuantizer requires dummy ' + 'input in order to perform certain optimizations') super(PostTrainLinearQuantizer, self).prepare_model(dummy_input) - self.acts_quant_params = OrderedDict(self.named_acts_quant_params()) + save_dir = msglogger.logdir if hasattr(msglogger, 'logdir') else '.' 
+ self.save_per_layer_parameters(save_dir) def _pre_prepare_model(self, dummy_input): if not self.has_bidi_distiller_lstm: @@ -1519,7 +1865,11 @@ class PostTrainLinearQuantizer(Quantizer): named_modules = OrderedDict(self.model.named_modules()) model_stats = self.model_activation_stats for n, m in named_modules.items(): - if (distiller.has_children(m) and not isinstance(m, SimulatedFoldedBatchNorm) )\ + # Don't fuse if module outputs aren't quantized: + qbits = self.module_qbits_map.get(n, QBits(None, None, None)) + if qbits.acts is None: + continue + if (distiller.has_children(m) and not isinstance(m, SimulatedFoldedBatchNorm))\ or n not in self.adjacency_map or len(self.adjacency_map[n].successors) != 1: continue successor = self.adjacency_map[n].successors[0] @@ -1570,12 +1920,11 @@ class PostTrainLinearQuantizer(Quantizer): def _apply_fuse_relu(self): """Fuses ReLU layers to the linear layers before them.""" model_overrides = self.module_overrides_map - qbits_map = self.module_qbits_map named_modules = dict(self.model.named_modules()) for n, m in named_modules.items(): - # Don't fuse if the module isn't quantized: - qbits = qbits_map.get(n, QBits(None, None, None)) - if qbits.acts is None and qbits.wts is None: + # Don't fuse if module outputs aren't quantized: + qbits = self.module_qbits_map.get(n, QBits(None, None, None)) + if qbits.acts is None: continue if (distiller.has_children(m) and not isinstance(m, SimulatedFoldedBatchNorm))\ or n not in self.adjacency_map or len(self.adjacency_map[n].successors) != 1: @@ -1602,14 +1951,13 @@ class PostTrainLinearQuantizer(Quantizer): self._clip_stats(self.model_activation_stats[name]['output'], -sat_val, sat_val) def _post_prepare_model(self): - if isinstance(self.model, nn.DataParallel): - # We restore the buffers to the master-GPU of the modules: - device = self.model.src_device_obj - m = self.model.module - for param in m.parameters(): - param.data = param.data.to(device) - for buffer in m.buffers(): - buffer.data = buffer.data.to(device) + m = self.model + device = distiller.model_device(m) + for param in m.parameters(): + param.data = param.data.to(device) + for buffer in m.buffers(): + buffer.data = buffer.data.to(device) + self.linear_quant_params = OrderedDict(self.named_linear_quant_params()) ############################################################################### diff --git a/distiller/utils.py b/distiller/utils.py index 3c5c6b3..2af8be9 100755 --- a/distiller/utils.py +++ b/distiller/utils.py @@ -40,6 +40,8 @@ msglogger = logging.getLogger() def model_device(model): """Determine the device the model is allocated on.""" # Source: https://discuss.pytorch.org/t/how-to-check-if-model-is-on-cuda/180 + if isinstance(model, nn.DataParallel): + return model.src_device_obj try: return str(next(model.parameters()).device) except StopIteration: @@ -785,4 +787,10 @@ def model_setattr(model, attr_name, val, register=False): def param_name_2_module_name(param_name): - return '.'.join(param_name.split('.')[:-1]) \ No newline at end of file + return '.'.join(param_name.split('.')[:-1]) + + +def is_scalar(val): + result = isinstance(val, torch.Tensor) and val.dim() == 0 + result |= np.isscalar(val) + return result diff --git a/examples/quantization/post_train_quant/resnet18_imagenet_post_train_lapq.yaml b/examples/quantization/post_train_quant/resnet18_imagenet_post_train_lapq.yaml new file mode 100644 index 0000000..4616d21 --- /dev/null +++ b/examples/quantization/post_train_quant/resnet18_imagenet_post_train_lapq.yaml @@ -0,0 +1,101 @@ 
+# Post-training quantization settings for running in the same conditions as in:
+# Nahshan et al., "Loss Aware Post-training Quantization" (https://arxiv.org/abs/1911.07190), according to the
+# reference implementation found at: https://github.com/ynahshan/nn-quantization-pytorch/tree/master/lapq
+#
+# The settings are:
+#   * Only fake-quantization is done
+#   * Only weights of convolutions and outputs of ReLU are quantized. Pooling, element-wise addition and FC layers
+#     are not quantized
+#   * The first convolution + relu pair isn't quantized
+#   * The last convolution + relu pair ("layer4.1.conv2/relu2") isn't quantized
+#
+# See example invocations and results after the YAML definition
+
+quantizers:
+  post_train_quantizer:
+    class: PostTrainLinearQuantizer
+    # Don't quantize anything by default; override below for what should be quantized
+    bits_activations: null
+    bits_parameters: null
+    bits_accum: 32
+    mode:
+      activations: ASYMMETRIC_UNSIGNED
+      weights: SYMMETRIC
+    model_activation_stats: ../../examples/quantization/post_train_quant/stats/resnet18_quant_stats.yaml
+    per_channel_wts: False
+    inputs_quant_auto_fallback: True
+
+    overrides:
+      # Conv layers inside the ResNet BasicBlock are quantized (except last one, see below)
+      layer.*conv.*:
+        bits_activations: null
+        bits_weights: 4
+      # ReLU layers inside the ResNet BasicBlock are quantized (except last one, see below)
+      layer.*relu.*:
+        bits_activations: 4
+        quantize_inputs: False
+      # Conv layers in downsampling residual connections are quantized
+      .*downsample\.0.*:
+        bits_activations: null
+        bits_weights: 4
+      # The last conv+relu layers are NOT quantized; we specify them directly
+      layer4.1.conv2:
+        bits_activations: null
+        bits_weights: null
+      layer4.1.relu2:
+        bits_activations: null
+
+# Example invocations:
+#   * Preliminaries:
+#       cd <distiller_root>/distiller/quantization
+#       CONFIG_FILE="<distiller_root>/examples/quantization/post_train_quant/resnet18_imagenet_post_train_lapq.yaml"
+#
+#   * Using L3 initialization:
+#     Command:
+#       python ptq_coordinate_search.py -a resnet18 --pretrained <path_to_imagenet> --opt-val-size 0.01 --opt-maxiter 2 --qe-config-file $CONFIG_FILE -b 500 --opt-init-mode L3 --opt-init-method powell --opt-eval-memoize-dataloader --det --opt-search-clipping
+#
+#     Excerpts from output:
+#       ...
+#       Initializing quantizer...
+#       Initializing quantization parameters...
+#       ...
+#       Evaluating initial quantization score...
+#       Evaluation set loss after initialization 2.522
+#       Test: loss=2.650, top1=43.816, top5=68.840
+#       Using "Powell" minimization algorithm.
+#       ...
+#       980 evaluations: loss=1.962
+#       Iteration 0: Score=1.956
+#       Test: loss=2.117, top1=51.662, top5=76.606
+#       ...
+#       2200 evaluations: loss=1.929
+#       Iteration 1: Score=1.926
+#       Test: loss=2.116, top1=51.784, top5=76.712
+#       Optimization Done.
+#       Arch: resnet18 Test: top1 = 51.784 top5 = 76.712 loss = 2.116
+#
+#
+#   * Using LAPLACE initialization:
+#     Command:
+#       python ptq_coordinate_search.py -a resnet18 --pretrained <path_to_imagenet> --opt-val-size 0.01 --opt-maxiter 2 --qe-config-file $CONFIG_FILE -b 500 --opt-init-mode LAPLACE --opt-init-method powell --opt-eval-memoize-dataloader --det --opt-search-clipping
+#
+#     Excerpts from output:
+#       ...
+#       Initializing quantizer...
+#       Initializing quantization parameters...
+#       ...
+#       Evaluating initial quantization score...
+#       Evaluation set loss after initialization 3.376
+#       Evaluating on full test set...
+#       Test: loss=3.509, top1=29.492, top5=53.768
+#       Using "Powell" minimization algorithm.
+#       ...
+# 620 evaluations: loss=2.458 +# Iteration 0: Score=2.458 +# Test: loss=2.650, top1=42.700, top5=68.138 +# ... +# 1780 evaluations: loss=2.277 +# Iteration 1: Score=2.274 +# Test: loss=2.504, top1=45.164, top5=70.400 +# Optimization Done. +# Arch: resnet18 Test: top1 = 45.164 top5 = 70.400 loss = 2.504 diff --git a/tests/test_post_train_quant.py b/tests/test_post_train_quant.py index fbd5d14..95a0295 100644 --- a/tests/test_post_train_quant.py +++ b/tests/test_post_train_quant.py @@ -178,13 +178,13 @@ def linear_bias(): "mode, clip_acts, per_channel_wts, expected_output", [ (LinearQuantMode.ASYMMETRIC_UNSIGNED, ClipMode.NONE, False, - torch.tensor([[7.686200692, 0.241135708, 0.783691051]], dtype=torch.float32)), + torch.tensor([[7.687381776, 0.241172762, 0.783811475]], dtype=torch.float32)), (LinearQuantMode.ASYMMETRIC_UNSIGNED, ClipMode.NONE, True, - torch.tensor([[7.698823529, 0.241531719, 0.784978085]], dtype=torch.float32)), + torch.tensor([[7.699930796, 0.241566456, 0.785090983]], dtype=torch.float32)), (LinearQuantMode.SYMMETRIC, ClipMode.NONE, False, - torch.tensor([[7.718753843, 0.243110357, 0.790108661]], dtype=torch.float32)), + torch.tensor([[7.716609268, 0.243042812, 0.789889138]], dtype=torch.float32)), (LinearQuantMode.SYMMETRIC, ClipMode.NONE, True, - torch.tensor([[7.718753843, 0.243110357, 0.790108661]], dtype=torch.float32)) + torch.tensor([[7.716609268, 0.243042812, 0.789889138]], dtype=torch.float32)) ] ) def test_linear_layer_wrapper(linear_input, linear_weights, linear_bias, @@ -199,8 +199,8 @@ def test_linear_layer_wrapper(linear_input, linear_weights, linear_bias, linear_input = attach_quant_metadata(linear_input, 8, mode, stats=None, clip_mode=clip_acts, per_channel=False, num_stds=None, scale_approx_mult_bits=None) - with pytest.raises(RuntimeError): - model(linear_input) + # with pytest.raises(RuntimeError): + # model(linear_input) model.eval() @@ -688,34 +688,42 @@ def test_acts_quant_params_linear(act1_type, act2_type, bn_out_stats): model = LinearBNSplitAct(act1_type, act2_type) stats = gen_stats_for_model(model) stats['bn']['output'] = bn_out_stats - quantizer = PostTrainLinearQuantizer(model, model_activation_stats=deepcopy(stats)) + quantizer = PostTrainLinearQuantizer(model, model_activation_stats=deepcopy(stats), save_fp_weights=True) quantizer.prepare_model(torch.randn(10, 10)) # get quant params: expected_quant_params_keys = { 'linear.output_zero_point', 'linear.output_scale', + 'linear.w_scale', + 'linear.w_zero_point', 'act1.output_zero_point', 'act1.output_scale', 'act2.output_zero_point', 'act2.output_scale' } - assert set(quantizer.acts_quant_params) == expected_quant_params_keys - quantizer.set_act_quant_param('linear.output_zero_point', 2.) - quantizer.set_act_quant_param('linear.output_scale', 30.) + assert set(quantizer.linear_quant_params) == expected_quant_params_keys + quantizer.set_linear_quant_param('linear.output_zero_point', 2.) + quantizer.set_linear_quant_param('linear.output_scale', 30.) assert model.linear.output_zero_point == 2. assert model.linear.output_scale == 30. + assert model.linear.force_readjust == True + assert model.act1.force_readjust == True expected_quant_param_linear_dict = { 'output_zero_point': torch.tensor(2.), - 'output_scale': 30. 
+ 'output_scale': 30., + 'w_scale': model.linear.w_scale.item(), + 'w_zero_point': model.linear.w_zero_point.item() } - assert dict(model.linear.named_acts_quant_params()) == expected_quant_param_linear_dict + assert dict(model.linear.named_linear_quant_params()) == expected_quant_param_linear_dict new_config = { 'linear.output_zero_point': 4., 'act2.output_scale': 50 } - quantizer.update_acts_quant_params(new_config) + quantizer.update_linear_quant_params(new_config) assert model.linear.output_zero_point == 4 assert model.act2.output_scale == 50 + assert model.linear.force_readjust == True + assert model.act1.force_readjust == True class DummyWordLangModel(nn.Module): @@ -732,18 +740,20 @@ class DummyWordLangModel(nn.Module): @pytest.mark.filterwarnings('ignore:Iterating over a tensor might cause the trace to be incorrect') @pytest.mark.filterwarnings('ignore:Converting a tensor to a Python index might cause the trace to be incorrect') def test_acts_quant_params_rnn(rnn_model): - model = DummyWordLangModel(nn.Embedding(41, 20), rnn_model).cuda() + model = DummyWordLangModel(nn.Embedding(41, 20), rnn_model) stats = gen_stats_for_model(model) quantizer = PostTrainLinearQuantizer(model, model_activation_stats=deepcopy(stats)) - dummy_input = torch.randint(0, 41, size=(79, 23)) + dummy_input = torch.randint(0, 41, size=(10, 1)) quantizer.prepare_model(dummy_input) new_config = { 'rnn.rnn.cells.0.act_o.output_scale': 4, 'embedding.w_scale': torch.tensor(59.0) } - quantizer.update_acts_quant_params(new_config) + quantizer.update_linear_quant_params(new_config) assert model.rnn.rnn.cells[0].act_o.output_scale == 4 assert model.embedding.w_scale == 59.0 + assert model.rnn.rnn.cells[0].act_o.force_readjust.item() is True + assert model.rnn.rnn.cells[0].act_f.force_readjust.item() is True ############################################################################### @@ -767,20 +777,21 @@ def _fake_quant_tensor(tensor, n_bits, mode, per_channel): q_utils.linear_dequantize(tensor, scale, zp, inplace=True) -def _test_wts_only_quant(layer, x, per_channel, bias, num_bits): +def _test_wts_only_quant(layer, x, per_channel, bias, num_bits_wts, num_bits_accum): layer.weight.data = torch.rand_like(layer.weight) if bias: layer.bias.data = torch.rand_like(layer.bias) mode = LinearQuantMode.ASYMMETRIC_UNSIGNED - layer_ptq = RangeLinearQuantParamLayerWrapper(deepcopy(layer), None, num_bits, mode=mode, per_channel_wts=per_channel) + layer_ptq = RangeLinearQuantParamLayerWrapper(deepcopy(layer), None, num_bits_wts, num_bits_accum=num_bits_accum, + mode=mode, per_channel_wts=per_channel) layer_ptq.eval() layer_manual_q = deepcopy(layer) - _fake_quant_tensor(layer_manual_q.weight.data, num_bits, mode, per_channel) + _fake_quant_tensor(layer_manual_q.weight.data, num_bits_wts, mode, per_channel) assert torch.equal(layer_ptq.wrapped_module.weight, layer_manual_q.weight) if bias: - _fake_quant_tensor(layer_manual_q.bias.data, num_bits, mode, False) + _fake_quant_tensor(layer_manual_q.bias.data, num_bits_accum, mode, False) assert torch.equal(layer_ptq.wrapped_module.bias, layer_manual_q.bias) y_ptq = layer_ptq(x) @@ -795,7 +806,7 @@ def test_conv_layer_wrapper_params_only(per_channel, bias): layer = torch.nn.Conv2d(in_ch, 10, 3, bias=bias) x = torch.rand(5, in_ch, 5, 5) - _test_wts_only_quant(layer, x, per_channel, bias, 8) + _test_wts_only_quant(layer, x, per_channel, bias, 8, 32) def test_linear_layer_wrapper_params_only(per_channel, bias): @@ -805,4 +816,4 @@ def test_linear_layer_wrapper_params_only(per_channel, 
bias): x = torch.rand(5, in_features) - _test_wts_only_quant(layer, x, per_channel, bias, 8) + _test_wts_only_quant(layer, x, per_channel, bias, 8, 32) -- GitLab
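
A minimal sketch of how the renamed per-layer quantization-parameter API introduced in this patch might be driven once the quantizer is prepared. The method names and keyword arguments (save_fp_weights, named_linear_quant_params, set_linear_quant_param, update_linear_quant_params, save_per_layer_parameters) are taken from the diff above; the torchvision model, the stats-file path and the concrete parameter values are assumptions used only for illustration.

import torch
from torchvision import models
from distiller.quantization import PostTrainLinearQuantizer

model = models.resnet18()   # stand-in model; any model with pre-collected activation stats works
quantizer = PostTrainLinearQuantizer(model,
                                     model_activation_stats='acts_quant_stats.yaml',  # placeholder path
                                     save_fp_weights=True)   # keep an FP32 copy so weights can be re-quantized
quantizer.prepare_model(torch.randn(1, 3, 224, 224))          # dummy input shape is model-specific

# Scale / zero-point buffers are exposed as '<module_name>.<param_name>':
for name, buff in quantizer.named_linear_quant_params():
    print(name, buff)

# Changing one parameter updates its buffer and flags the wrappers to re-adjust their
# derived re-quantization factors on the next forward pass:
quantizer.set_linear_quant_param('layer1.0.conv1.output_scale', 30.)

# Batch update, e.g. with values produced by an external search such as LAPQ:
quantizer.update_linear_quant_params({
    'layer1.0.conv1.output_zero_point': 2.,
    'layer1.0.relu1.output_scale': 50.,
})

# The current per-layer configuration can be dumped to 'layer_quant_params.yaml':
quantizer.save_per_layer_parameters(save_dir='.')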
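
The clipping interface can be exercised the same way. The sketch below continues with the quantizer from the previous sketch and only re-uses names and values reported by named_clipping(), so no clipping-parameter names need to be assumed; re-quantizing weights after a clipping change relies on the FP32 copy kept by save_fp_weights=True, per the wrapper code in the diff.

# Each entry is ('<module_name>.<clipping_name>', (min, max)), per the docstring above:
for name, (clip_min, clip_max) in quantizer.named_clipping():
    print(name, clip_min, clip_max)

# Re-set a single clipping value; the setter recomputes scale/zero-point from the new
# range (and, for weight clipping, re-quantizes the saved FP32 weights):
name, (clip_min, clip_max) = next(iter(quantizer.named_clipping()))
quantizer.set_clipping(name, (clip_min, clip_max * 0.95))

# Or push a whole configuration at once, e.g. shrinking every range by 10%:
new_cfg = {n: (lo * 0.9, hi * 0.9) for n, (lo, hi) in quantizer.named_clipping()}
quantizer.update_clipping_parameters(new_cfg)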
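
Finally, a short illustration of the two distiller.utils additions. The printed behavior follows directly from the code in the diff; the CUDA branch is only an assumption about a typical multi-GPU setup.

import torch
import torch.nn as nn
import distiller

net = nn.Linear(4, 2)
print(distiller.model_device(net))                   # 'cpu' - device of the first parameter

if torch.cuda.is_available():
    dp_net = nn.DataParallel(net.cuda())
    print(distiller.model_device(dp_net))            # now taken from DataParallel.src_device_obj

# is_scalar() accepts both 0-dim tensors and Python/NumPy scalars:
print(distiller.utils.is_scalar(torch.tensor(3.)))   # True
print(distiller.utils.is_scalar(5))                  # True
print(distiller.utils.is_scalar(torch.ones(3)))      # False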