diff --git a/distiller/data_loggers/collector.py b/distiller/data_loggers/collector.py index 612f7d3f5646c31710a0a83ef7053d87904541b6..7fef7487a45363ad314567d95413a849fe700e41 100755 --- a/distiller/data_loggers/collector.py +++ b/distiller/data_loggers/collector.py @@ -82,14 +82,20 @@ class ActivationStatsCollector(object): self.model.apply(partial(self._collect_activations_stats, activation_stats=activation_stats)) return activation_stats - def start(self): + def start(self, modules_list=None): """Start collecting activation stats. This will iteratively register the modules' forward-hooks, so that the collector will be called from the forward traversal and get exposed to activation data. + modules_list (iterable): track stats only for the modules in this list. If None/empty - will track stats for all modules. """ assert len(self.fwd_hook_handles) == 0 - self.model.apply(self.start_module) + if not modules_list: + self.model.apply(self.start_module) + return + modules_dict = dict(self.model.named_modules()) + for module_name in modules_list: + modules_dict[module_name].apply(self.start_module) def start_module(self, module): """Iteratively register to the forward-pass callback of all eligible modules. @@ -707,7 +713,8 @@ class ActivationHistogramsCollector(ActivationStatsCollector): def collect_quant_stats(model, test_fn, save_dir=None, classes=None, inplace_runtime_check=False, - disable_inplace_attrs=False, inplace_attr_names=('inplace',)): + disable_inplace_attrs=False, inplace_attr_names=('inplace',), + modules_to_collect=None): """ Helper function for collecting quantization calibration statistics for a model using QuantCalibrationStatsCollector @@ -722,6 +729,8 @@ def collect_quant_stats(model, test_fn, save_dir=None, classes=None, inplace_run inplace_runtime_check (bool): See QuantCalibrationStatsCollector disable_inplace_attrs (bool): See QuantCalibrationStatsCollector inplace_attr_names (iterable): See QuantCalibrationStatsCollector + modules_to_collect (iterable): enable stats collection for a predefined set of modules (specified by their names). + If None - will track stats for all layers. 
Returns: Dictionary with quantization stats (see QuantCalibrationStatsCollector for a description of the dictionary @@ -732,7 +741,7 @@ def collect_quant_stats(model, test_fn, save_dir=None, classes=None, inplace_run inplace_runtime_check=inplace_runtime_check, disable_inplace_attrs=disable_inplace_attrs, inplace_attr_names=inplace_attr_names) - with collector_context(quant_stats_collector): + with collector_context(quant_stats_collector, modules_to_collect): msglogger.info('Pass 1: Collecting min, max, avg_min, avg_max, mean, std') test_fn(model=model) # Collect Laplace distribution stats: @@ -799,10 +808,10 @@ def collect_histograms(model, test_fn, save_dir=None, activation_stats=None, @contextmanager -def collector_context(collector): +def collector_context(collector, modules_list=None): """A context manager for an activation collector""" if collector is not None: - collector.reset().start() + collector.reset().start(modules_list) yield collector if collector is not None: collector.stop() diff --git a/distiller/model_transforms.py b/distiller/model_transforms.py index dca01f443d490684b6e538a8f095f53873271fbd..5761431bb02a48b50402dc174193f3b655776b0c 100644 --- a/distiller/model_transforms.py +++ b/distiller/model_transforms.py @@ -24,7 +24,7 @@ import logging msglogger = logging.getLogger() -__all__ = ["fuse_modules", "fold_batch_norms_inference"] +__all__ = ["fuse_modules", "fold_batch_norms"] def fuse_modules(model, types_sequence, fuse_fn, dummy_input=None, adjacency_map=None): @@ -98,7 +98,7 @@ def fuse_modules(model, types_sequence, fuse_fn, dummy_input=None, adjacency_map return model -def fold_batch_norms_inference(model, dummy_input=None, adjacency_map=None): +def fold_batch_norms(model, dummy_input=None, adjacency_map=None, inference=True): """Scans the model for convolution / linear modules followed by batch-normalization. For each such valid pair, folds the parameters of the batch normalization module into the parameters of the parameter module, and replaces the batch normalization module with an identity operation. @@ -112,6 +112,8 @@ def fold_batch_norms_inference(model, dummy_input=None, adjacency_map=None): adjacency_map (OrderedDict): Pre-computed adjacency map, via SummaryGraph.adjacency_map(). Must be based on the passed model, otherwise results are unexpected. If None, then the adjacency map will be created internally using the passed dummy_input. + inference (bool): an indicator on whether or not the modules are in inference mode. + This will hard-fuse all BatchNorms into the param-layers. """ def fold_bn(sequence): # Re-use this functionality from simulated BN folding implementation @@ -121,8 +123,10 @@ def fold_batch_norms_inference(model, dummy_input=None, adjacency_map=None): except ValueError: msglogger.debug("Can't fold, {} does not track running stats".format(bn_module.distiller_name)) return None - folded_module.freeze() - return folded_module.param_module + if inference: + folded_module.freeze() + return folded_module.param_module + return folded_module foldables = (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d) batchnorms = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d) diff --git a/distiller/models/__init__.py b/distiller/models/__init__.py index bf628db183fdfbc8187dd95ba543e8c3f3cb0c17..bedc28a5f12c7f1252150636da8aaf1aadd8be2e 100755 --- a/distiller/models/__init__.py +++ b/distiller/models/__init__.py @@ -20,12 +20,14 @@ import copy from functools import partial import torch import torchvision.models as torch_models +import torch.nn as nn from . 
import cifar10 as cifar10_models from . import mnist as mnist_models from . import imagenet as imagenet_extra_models import pretrainedmodels from distiller.utils import set_model_input_shape_attr +from distiller.modules import Mean, EltwiseAdd import logging msglogger = logging.getLogger() @@ -59,16 +61,41 @@ ALL_MODEL_NAMES = sorted(map(lambda s: s.lower(), set(IMAGENET_MODEL_NAMES + CIFAR10_MODEL_NAMES + MNIST_MODEL_NAMES))) -# A temporary monkey-patch to get past this Torchvision bug: -# https://github.com/pytorch/pytorch/issues/20516 -def patch_torchvision_mobilenet_v2_bug(model): - def patched_forward(self, x): +def patch_torchvision_mobilenet_v2(model): + """ + Patches TorchVision's MobileNetV2: + * To allow quantization, this adds modules for tensor operations (mean, element-wise addition) to the + model instance and patches the forward functions accordingly + * Fixes a bug in the torchvision implementation that prevents export to ONNX (and creation of SummaryGraph) + """ + if not isinstance(model, torch_models.MobileNetV2): + raise TypeError("Only MobileNetV2 is acceptable.") + + def patched_forward_mobilenet_v2(self, x): x = self.features(x) - #x = x.mean([2, 3]) - x = x.mean(3).mean(2) + # x = x.mean([2, 3]) # this was a bug: https://github.com/pytorch/pytorch/issues/20516 + x = self.mean32(x) x = self.classifier(x) return x - model.__class__.forward = patched_forward + model.mean32 = nn.Sequential( + Mean(3), Mean(2) + ) + model.__class__.forward = patched_forward_mobilenet_v2 + + def is_inverted_residual(module): + return isinstance(module, nn.Module) and module.__class__.__name__ == 'InvertedResidual' + + def patched_forward_invertedresidual(self, x): + if self.use_res_connect: + return self.residual_eltwiseadd(self.conv(x), x) + else: + return self.conv(x) + + for n, m in model.named_modules(): + if is_inverted_residual(m): + if m.use_res_connect: + m.residual_eltwiseadd = EltwiseAdd() + m.__class__.forward = patched_forward_invertedresidual _model_extensions = {} @@ -137,7 +164,7 @@ def _create_imagenet_model(arch, pretrained): try: model = getattr(torch_models, arch)(pretrained=pretrained) if arch == "mobilenet_v2": - patch_torchvision_mobilenet_v2_bug(model) + patch_torchvision_mobilenet_v2(model) except NotImplementedError: # In torchvision 0.3, trying to download a model that has no # pretrained image available will raise NotImplementedError diff --git a/distiller/models/cifar10/resnet_cifar.py b/distiller/models/cifar10/resnet_cifar.py index e9ce4e55ab026ba545223d42bce061328f4105d3..ca31731feaa115a21420b80973e6b23d1fd1f09f 100755 --- a/distiller/models/cifar10/resnet_cifar.py +++ b/distiller/models/cifar10/resnet_cifar.py @@ -37,6 +37,7 @@ This ResNet also has layer gates, to be able to dynamically remove layers. 
import torch.nn as nn import math import torch.utils.model_zoo as model_zoo +from distiller.modules import EltwiseAdd __all__ = ['resnet20_cifar', 'resnet32_cifar', 'resnet44_cifar', 'resnet56_cifar'] @@ -62,6 +63,7 @@ class BasicBlock(nn.Module): self.relu2 = nn.ReLU(inplace=False) self.downsample = downsample self.stride = stride + self.residual_eltwiseadd = EltwiseAdd() def forward(self, x): residual = out = x @@ -78,7 +80,7 @@ class BasicBlock(nn.Module): if self.downsample is not None: residual = self.downsample(x) - out += residual + out = self.residual_eltwiseadd(residual, out) out = self.relu2(out) return out diff --git a/distiller/modules/__init__.py b/distiller/modules/__init__.py index 03c3f57f4898a98cbe3f8649edb86aa0b3bf855b..71c72eef6d8cd75b5c4dc83ca9d4fc5a881bb8ac 100644 --- a/distiller/modules/__init__.py +++ b/distiller/modules/__init__.py @@ -18,9 +18,9 @@ from .eltwise import * from .grouping import * from .matmul import * from .rnn import * -from .aggregate import Norm +from .aggregate import * __all__ = ['EltwiseAdd', 'EltwiseMult', 'EltwiseDiv', 'Matmul', 'BatchMatmul', 'Concat', 'Chunk', 'Split', 'Stack', 'DistillerLSTMCell', 'DistillerLSTM', 'convert_model_to_distiller_lstm', - 'Norm'] + 'Norm', 'Mean'] diff --git a/distiller/modules/aggregate.py b/distiller/modules/aggregate.py index a217627440cb1d929f44a77a37a31de5df55b29a..f6f6cd469d4c215afd120b378d8cb11048050f0f 100644 --- a/distiller/modules/aggregate.py +++ b/distiller/modules/aggregate.py @@ -14,3 +14,13 @@ class Norm(nn.Module): def forward(self, x: torch.Tensor): return torch.norm(x, p=self.p, dim=self.dim, keepdim=self.keepdim) + + +class Mean(nn.Module): + def __init__(self, *args, **kwargs): + super(Mean, self).__init__() + self.args = args + self.kwargs = kwargs + + def forward(self, x: torch.Tensor): + return torch.mean(x, *self.args, **self.kwargs) diff --git a/distiller/quantization/ptq_greedy_search.py b/distiller/quantization/ptq_greedy_search.py new file mode 100644 index 0000000000000000000000000000000000000000..1f361febff5580555f179ee951626ff86602dd8d --- /dev/null +++ b/distiller/quantization/ptq_greedy_search.py @@ -0,0 +1,488 @@ +# +# Copyright (c) 2018 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +r""" +Here we implement the greedy search algorithm for automatic quantization. 
+""" +import torch +import torch.nn as nn +from distiller.quantization.range_linear import PostTrainLinearQuantizer, ClipMode, LinearQuantMode +from distiller.summary_graph import SummaryGraph +from distiller.model_transforms import fold_batch_norms +import distiller.modules +from distiller.data_loggers import collect_quant_stats +from distiller.models import create_model +from collections import OrderedDict, defaultdict +import logging +from copy import deepcopy +import distiller.apputils.image_classifier as classifier +import os +import distiller.apputils as apputils +import re +import argparse + +__all__ = ['ptq_greedy_search'] + +msglogger = None + +QUANTIZED_MODULES = ( + nn.Linear, + nn.Conv2d, + nn.Conv3d, + distiller.modules.Concat, + distiller.modules.EltwiseAdd, + distiller.modules.EltwiseMult, + distiller.modules.Matmul, + distiller.modules.BatchMatmul +) + +FP16_LAYERS = ( + nn.Tanh, + nn.Sigmoid +) + +PARAM_MODULES = ( + nn.Linear, + nn.Conv2d, + nn.Conv3d +) + +UNQUANTIZED_MODULES = ( + nn.Softmax, +) + +SKIP_MODULES = ( + nn.Identity, + nn.BatchNorm1d, + nn.BatchNorm2d, + nn.BatchNorm3d, + nn.ReLU, + nn.ReLU6 +) + +CLIP_MODES = ['NONE', + 'AVG', + 'GAUSS', + 'LAPLACE' + ] + + +def get_default_args(): + parser = classifier.init_classifier_compression_arg_parser() + parser.add_argument('--qe-no-quant-layers', '--qenql', type=str, nargs='+', metavar='LAYER_NAME', default=[], + help='List of layer names for which to skip quantization.') + parser.add_argument('--qe-calib-portion', type=float, default=1.0, + help='The portion of the dataset to use for calibration stats collection.') + parser.add_argument('--qe-calib-batchsize', type=int, default=256, + help='The portion of the dataset to use for calibration stats collection.') + parser.add_argument('--base-score', type=float, default=None) + parser.add_argument('--quantize-inputs', type=str, nargs='+', metavar='LAYER_NAME#INPUT_IDX', default=[], + help='The inputs of layers to quantize') + parser.add_argument('--resume-search-from', type=str, help='Search checkpoint file to resume.', + default=None) + args = parser.parse_args() + return args + + +def override_odict(**kwargs): + return OrderedDict(kwargs) + + +def get_inputs_to_quantize(sg, args, recurrent=False): + """ + Finds the modules in the graph that take the user input directly + Args: + sg (SummaryGraph): the summary graph of the model + recurrent: see SummaryGraph.layers_topological_order + TODO - implement properly + """ + # input_modules = set() + # layers = set(sg.layers_topological_order(recurrent)) + # for op in sg.top_level_ops(): + # input_modules.update(set(sg.successors(op, 1)) & layers) + # return list(input_modules) + result = defaultdict(lambda: []) + for input_str in args.quantize_inputs: + module_name, input_idx_str = input_str.split('#') + input_idx = int(input_idx_str) + result[module_name].append(input_idx) + return result + + +def input_override_generator(module, module_name, sg, overrides_dict, **kwargs): + """ + Generator for overrides on inputs of the input layers. 
+ Args: + module (nn.Module): the module + module_name (str): module name as it appears in the summary graph + sg (SummaryGraph): a summary graph of the model + overrides_dict (OrderedDict): the fixed overrides already applied + kwargs: additional arguments, if needed + """ + bits_acts = kwargs.get('bits_activations', 8) + bits_wts = kwargs.get('bits_weights', 8) + input_nodes = sg.predecessors(module_name, 1) + input_idx = kwargs.get('input_idx', 0) + assert input_idx < len(input_nodes) + for clip_mode in CLIP_MODES: + input_idx_override = override_odict(bits_activations=bits_acts, + clip_acts=clip_mode) + input_overrides = OrderedDict([(input_idx, input_idx_override)]) + current_module_override = override_odict(input_overrides=input_overrides) + # add basic quantization so the quantizer doesn't reject this override + current_module_override['bits_activations'] = bits_acts + if isinstance(module, PARAM_MODULES): + current_module_override['bits_weights'] = bits_wts + yield current_module_override + + +def module_override_generator(module, module_name, sg, overrides_dict, **kwargs): + """ + Standard generator of overrides for the greedy search algorithm. + Args: + module (nn.Module): the module + module_name (str): module name as it appears in the summary graph + sg (SummaryGraph): a summary graph of the model + overrides_dict (OrderedDict): the fixed overrides already applied + kwargs: additional arguments, if needed + """ + bits_acts = kwargs.get('bits_activations', 8) + bits_wts = kwargs.get('bits_weights', 8) + if isinstance(module, nn.ReLU): + yield override_odict(make_identity=True, + bits_weights=bits_wts, + bits_activations=bits_acts) + return + adj_map = sg.adjacency_map() + modules_dict = dict(sg._src_model.named_modules()) + successors_names = {op.name for op in adj_map[module_name].successors if op.name in modules_dict} + use_half_range = all([isinstance(modules_dict[succ], nn.ReLU) for succ in successors_names]) + use_fake = False + fpq_module = None + if isinstance(module, FP16_LAYERS): + fpq_module = 16 + use_fake = True + if isinstance(module, UNQUANTIZED_MODULES) or not isinstance(module, QUANTIZED_MODULES): + fpq_module = 32 + use_fake = True + for clip_mode in CLIP_MODES: + if isinstance(module, PARAM_MODULES): + current_module_override = override_odict(clip_acts=clip_mode, + bits_weights=bits_wts, + bits_activations=bits_acts, + bits_bias=32) + else: + current_module_override = override_odict(clip_acts=clip_mode, + fpq_module=fpq_module, + fake=use_fake, + bits_weights=bits_wts, + bits_activations=bits_acts) + current_module_override['clip_half_range'] = use_half_range and clip_mode in ['GAUSS', 'LAPLACE'] + + yield current_module_override + + +def search_best_local_settings(module, module_name, sg, act_stats, eval_fn, best_overrides_dict, override_gen_fn, + **kwargs): + msglogger.info('Searching optimal quantization in \'%s\'(%s):' % (module_name, module.__class__.__name__)) + overrides_dict = deepcopy(best_overrides_dict) + best_performance, best_local_override = float("-inf"), OrderedDict() + normalized_module_name = module_name + if isinstance(model, nn.DataParallel): + normalized_module_name = re.sub(r'module\.', '', normalized_module_name) + for local_override in override_gen_fn(module, module_name, sg, best_overrides_dict, **kwargs): + if not overrides_dict.get(normalized_module_name, None): + overrides_dict[normalized_module_name] = OrderedDict() + overrides_dict[normalized_module_name].update(local_override) + temp_act_stats = deepcopy(act_stats) + quantizer = 
PostTrainLinearQuantizer(deepcopy(model), + bits_activations=None, + bits_parameters=None, + bits_accum=32, + mode=LinearQuantMode.ASYMMETRIC_SIGNED, + clip_acts=ClipMode.NONE, + overrides=deepcopy(overrides_dict), + model_activation_stats=deepcopy(temp_act_stats), + inputs_quant_auto_fallback=False, + per_channel_wts=kwargs.get('per_channel', False)) + quantizer.prepare_model(dummy_input) + + current_performance = eval_fn(quantizer.model) + if not isinstance(module, UNQUANTIZED_MODULES): + clip_mode = local_override.get('clip_acts', None) + msglogger.info('\t%s\t score = %.3f\tLayer overrides: %s' % + (clip_mode or '', current_performance, local_override)) + else: + msglogger.info('\t Module is not quantized to int8. Not clipping activations.') + msglogger.info('\t score = %.3f\tLayer overrides: %s' % + (current_performance, local_override)) + if current_performance > best_performance: + best_performance = current_performance + best_local_override = local_override + + msglogger.info('\t Choosing overrides: %s' % best_local_override) + return best_local_override + + +def ptq_greedy_search(model, dummy_input, eval_fn, calib_eval_fn=None, + recurrent=False, act_stats=None, + args=None, + module_override_gen_fn=None, input_override_gen_fn=None, + fold_sequences=True): + """ + Perform a greedy search over the post-training quantization configuration of the model. + Args: + model (nn.Module): the model to quantize + dummy_input (torch.Tensor): a dummy input to be passed to the model + eval_fn (function): Test/Evaluation function for the model. It must have an argument named 'model' that + accepts the model. All other arguments should be set in advance (can be done using functools.partial), or + they will be left with their default values. + calib_eval_fn (function): An 'evaluation' function used for forward passing + through the model to collect quantization calibration statistics. + If None is provided - `eval_fn` will be used as the default. + recurrent (bool): a flag to indicate whether the model has recurrent connections. + act_stats (OrderedDict): quant calibration activation stats. + If None is provided - they will be calculated at runtime. + args (dict or argparse.Namespace): command line arguments, either as an argparse.Namespace or an equivalent dict. + module_override_gen_fn: A function to generate module overrides. + Assumes the signature + `def module_override_gen_fn(module: nn.Module, + module_name: str, + sg: distiller.SummaryGraph, + overrides_dict: OrderedDict, + **kwargs)-> Generator[OrderedDict, None, None]` + input_override_gen_fn: Same as module_override_gen_fn, only used for quantizing the inputs of the top-level layers. + fold_sequences (bool): whether to fold batch norms before quantizing + Returns: + (quantized_model, best_overrides_dict) + Note: + It is assumed that `eval_fn` returns a scalar performance metric (e.g. accuracy) + and the greedy search aims to maximize this metric. 
+ """ + if args is None: + args = get_default_args() + elif isinstance(args, dict): + updated_args = get_default_args() + updated_args.__dict__.update(args) + args = updated_args + + if fold_sequences: + model = fold_batch_norms(model, dummy_input) + best_overrides_dict = OrderedDict() + if args.resume_search_from: + with open(args.resume_search_from, 'r') as f: + best_overrides_dict = distiller.yaml_ordered_load(f) + msglogger.info('Loaded search checkpoint from %s' % args.resume_search_from) + overrides_dict = OrderedDict() + sg = SummaryGraph(model, dummy_input) + modules_to_quantize = sg.layers_topological_order(recurrent) + adjacency_map = sg.adjacency_map() + modules_dict = OrderedDict(model.named_modules()) # type: OrderedDict[str, nn.Module] + modules_to_quantize = [m for m in modules_to_quantize + if m not in args.qe_no_quant_layers] + + module_override_gen_fn = module_override_gen_fn or module_override_generator + input_override_gen_fn = input_override_gen_fn or input_override_generator + + calib_eval_fn = calib_eval_fn or eval_fn + if not act_stats: + msglogger.info('Collecting stats for model...') + model_temp = distiller.utils.make_non_parallel_copy(model) + act_stats = collect_quant_stats(model_temp, calib_eval_fn) + del model_temp + if args: + act_stats_path = '%s_act_stats.yaml' % args.arch + msglogger.info('Done. Saving act stats into %s' % act_stats_path) + distiller.yaml_ordered_save(act_stats_path, act_stats) + msglogger.info('Evaluating baseline score for model...') + base_score = args.base_score or eval_fn(model) + msglogger.info("Base score: %.3f" % base_score) + + def recalibrate_stats(module_name, act_stats): + """ + Re-collects quant-calibration stats for successor modules of the current module. + """ + msglogger.info('Recalibrating stats...') + modules_to_recalibrate = {op.name for op in adjacency_map[module_name].successors} & set(act_stats) + if not modules_to_recalibrate: + # either there aren't any successors or + # the successors aren't in the stats file - skip + return act_stats + q = PostTrainLinearQuantizer(distiller.utils.make_non_parallel_copy(model), + bits_activations=None, + bits_parameters=None, + bits_accum=32, + mode=LinearQuantMode.ASYMMETRIC_SIGNED, + clip_acts=ClipMode.NONE, + overrides=deepcopy(best_overrides_dict), + model_activation_stats=deepcopy(act_stats), + inputs_quant_auto_fallback=False, + per_channel_wts=args.qe_per_channel) + q.prepare_model(dummy_input) + # recalibrate on the current best quantized version of the model. + recalib_act_stats = collect_quant_stats(q.model, calib_eval_fn, modules_to_collect=modules_to_recalibrate) + msglogger.info('Done.') + act_stats.update(recalib_act_stats) + return act_stats + + loaded_from_checkpoint = [] + # Quantize inputs: + input_modules = get_inputs_to_quantize(sg, args, recurrent) # top level modules + for module_name, input_idxs in input_modules.items(): + denormalized_module_name = distiller.denormalize_module_name(model, module_name) + module = modules_dict[denormalized_module_name] + if isinstance(module, SKIP_MODULES): + msglogger.info('Skipping module \'%s\' of type %s.' 
% (module_name, type(module))) + continue + msglogger.info('Quantizing top level inputs for %s' % module_name) + + normalized_module_name = module_name + if isinstance(model, nn.DataParallel): + normalized_module_name = re.sub(r'module\.', '', normalized_module_name) + if normalized_module_name in best_overrides_dict and \ + best_overrides_dict[normalized_module_name].get('input_overrides', None): + # This means the loaded dict already has the module + msglogger.info(" Quantizing '%s' based on loaded checkpoint: %s" % + (module_name, best_overrides_dict[normalized_module_name])) + if best_overrides_dict[normalized_module_name].get('bits_activations'): + loaded_from_checkpoint.append(normalized_module_name) + continue + if not best_overrides_dict.get(normalized_module_name, None): + best_overrides_dict[normalized_module_name] = OrderedDict() + for input_idx in input_idxs: + best_module_override = search_best_local_settings(module, module_name, sg, act_stats, eval_fn, + best_overrides_dict, + input_override_gen_fn, input_idx=input_idx, + bits_activations=args.qe_bits_acts, + bits_weights=args.qe_bits_wts, + per_channel=args.qe_per_channel) + best_overrides_dict[normalized_module_name].update(best_module_override) + # Leave only the input_overrides settings: + current_input_overrides = best_overrides_dict[normalized_module_name]['input_overrides'] + best_overrides_dict[normalized_module_name] = override_odict(input_overrides=current_input_overrides) + + # Quantize layers as a whole: + for module_name in modules_to_quantize: + module = modules_dict[module_name] + if isinstance(module, SKIP_MODULES): + msglogger.info('Skipping module \'%s\' of type %s.' % (module_name, module.__class__.__name__)) + continue + + normalized_module_name = module_name + if isinstance(model, nn.DataParallel): + normalized_module_name = re.sub(r'module\.', '', normalized_module_name) + + if normalized_module_name in best_overrides_dict and \ + best_overrides_dict[normalized_module_name].get('bits_activations', None)\ + and normalized_module_name not in loaded_from_checkpoint: + # This means the loaded dict already has the module + msglogger.info(" Quantizing '%s'(%s) based on loaded checkpoint: %s" % + (module_name, module.__class__.__name__, best_overrides_dict[normalized_module_name])) + loaded_from_checkpoint.append(normalized_module_name) + continue + if not best_overrides_dict.get(normalized_module_name, None): + best_overrides_dict[normalized_module_name] = OrderedDict() + # Hard coded workaround for avgpool->reshape->fc + if normalized_module_name == 'fc': + input_override = override_odict(bits_activations=8, + clip_acts='NONE') + best_overrides_dict['fc'].update(OrderedDict([ + ('input_overrides', OrderedDict([ + (0, input_override) + ])) + ])) + best_module_override = search_best_local_settings(module, module_name, sg, act_stats, eval_fn, + best_overrides_dict, + module_override_gen_fn, + bits_activations=args.qe_bits_acts, + bits_weights=args.qe_bits_wts, + per_channel=args.qe_per_channel) + best_overrides_dict[normalized_module_name].update(best_module_override) + distiller.yaml_ordered_save('%s.ptq_greedy_search.yaml' % args.arch, best_overrides_dict) + # # end of search - we update the calibration of the next layers: + # recalibrate_stats(module_name, act_stats) + + quantizer = PostTrainLinearQuantizer(model, + bits_activations=None, + bits_parameters=None, + bits_accum=32, + mode=LinearQuantMode.ASYMMETRIC_SIGNED, + clip_acts=ClipMode.NONE, + overrides=deepcopy(best_overrides_dict), + 
model_activation_stats=act_stats, + inputs_quant_auto_fallback=False, + per_channel_wts=args.qe_per_channel) + quantizer.prepare_model(dummy_input) + msglogger.info('best_overrides_dict: %s' % best_overrides_dict) + msglogger.info('Best score: %f'% eval_fn(quantizer.model)) + return model, best_overrides_dict + + +if __name__ == "__main__": + args = get_default_args() + args.epochs = float('inf') # hack for args parsing so there's no error in epochs + cc = classifier.ClassifierCompressor(args, script_dir=os.path.dirname(__file__)) + eval_data_loader = classifier.load_data(args, load_train=False, load_val=False) + + # quant calibration dataloader: + args.effective_test_size = args.qe_calib_portion + args.batch_size = args.qe_calib_batchsize + calib_data_loader = classifier.load_data(args, load_train=False, load_val=False) + # logging + logging.getLogger().setLevel(logging.WARNING) + msglogger = logging.getLogger(__name__) + msglogger.setLevel(logging.INFO) + + def test_fn(model): + top1, top5, losses = classifier.test(eval_data_loader, model, cc.criterion, [cc.tflogger, cc.pylogger], None, + args) + return top1 + + def calib_eval_fn(model): + classifier.test(calib_data_loader, model, cc.criterion, [], None, + args) + + model = create_model(args.pretrained, args.dataset, args.arch, + parallel=not args.load_serialized, device_ids=args.gpus) + args.device = next(model.parameters()).device + if args.resumed_checkpoint_path: + args.load_model_path = args.resumed_checkpoint_path + if args.load_model_path: + msglogger.info("Loading checkpoint from %s" % args.load_model_path) + model = apputils.load_lean_checkpoint(model, args.load_model_path, + model_device=args.device) + dummy_input = torch.rand(*model.input_shape, device=args.device) + if args.qe_stats_file: + msglogger.info("Loading stats from %s" % args.qe_stats_file) + with open(args.qe_stats_file, 'r') as f: + act_stats = distiller.yaml_ordered_load(f) + else: + act_stats = None + model, overrides = ptq_greedy_search(model, dummy_input, test_fn, + calib_eval_fn=calib_eval_fn, args=args, + act_stats=act_stats) + # Prepare a compression scheduler yaml config file: + quantizer_dict = OrderedDict([ + ('class', 'PostTrainLinearQuantizer') + ]) + quantizer_dict.update(deepcopy(model.quantizer_metadata['params'])) + quantizer_dict['overrides'] = overrides + quantizer_dict['model_activation_stats'] = os.path.abspath('%s_act_stats.yaml' % args.arch) + sched_dict = OrderedDict([ + ('quantizers', OrderedDict([ + ('post_train_quantizer', quantizer_dict) + ])) + ]) + distiller.yaml_ordered_save('%s.ptqgs_quantizer_sched_dict.yaml' % args.arch, sched_dict) diff --git a/distiller/quantization/quantizer.py b/distiller/quantization/quantizer.py index 2b3ac788e21b034fd8d87d38ea464825926b6309..84d0a246c44bf0b784a9bb58bd2228924fe78681 100644 --- a/distiller/quantization/quantizer.py +++ b/distiller/quantization/quantizer.py @@ -102,6 +102,10 @@ class Quantizer(object): 3.1 Gradients calculated with respect to q_weights 3.2 We also back-prop through the 'quantize' operation from step 1 4. Update fp_weights with gradients calculated in step 3.2 + Note: + The `overrides` dictionary assumes the keys are *not* the module names in the + `nn.DataParallel` case - i.e. without the `module.` prefix. 
e.g.: + module.conv1 -> OrderedDict([('conv1', OrderedDict(...))]) """ def __init__(self, model, optimizer=None, bits_activations=None, bits_weights=None, bits_bias=None, @@ -147,7 +151,7 @@ class Quantizer(object): regex_overrides_str = '|'.join(['(^{0}$)'.format(pattern) for pattern in patterns]) regex_overrides = re.compile(regex_overrides_str) - self.module_qbits_map = {} + self.module_qbits_map = {} # type: OrderedDict[str, QBits] self.module_overrides_map = {} for module_full_name, module in model.named_modules(): # Need to account for scenario where model is parallelized with DataParallel, which wraps the original @@ -171,7 +175,8 @@ class Quantizer(object): # Mapping from module type to function generating a replacement module suited for quantization # To be populated by child classes # Unspecified layer types return None by default. - self.replacement_factory = defaultdict(lambda: None) + self.replacement_factory = OrderedDict([(nn.Identity, None)]) + self.default_repalcement_fn = None # Pointer to parameters quantization function, triggered during training process # To be populated by child classes self.param_quantization_fn = None @@ -252,6 +257,8 @@ class Quantizer(object): # Re-transfer model to the device it was on, in case the quantizer created new parameters/buffers self.model.to(model_device) + distiller.assign_layer_fq_names(self.model) + msglogger.info('Quantized model:\n\n{0}\n'.format(self.model)) def _pre_prepare_model(self, dummy_input): @@ -281,15 +288,15 @@ class Quantizer(object): replace_msg(full_name) continue current_qbits = self.module_qbits_map[full_name] - if current_qbits.acts is None and current_qbits.wts is None: - if self.module_overrides_map[full_name]: - raise ValueError("Adding overrides while not quantizing is not allowed.") + # TODO - Review necessity of the block below + if current_qbits.acts is None and current_qbits.wts is None and not self.module_overrides_map[full_name]: # We indicate this module wasn't replaced by a wrapper replace_msg(full_name) self.modules_processed[module] = full_name, None else: # We use a type hint comment to let IDEs know replace_fn is a function - replace_fn = self.replacement_factory[type(module)] # type: Optional[Callable] + replace_fn = self.replacement_factory.get(type(module), + self.default_repalcement_fn) # type: Optional[Callable] # If the replacement function wasn't specified - continue without replacing this module. if replace_fn is not None: valid_kwargs, invalid_kwargs = distiller.filter_kwargs(self.module_overrides_map[full_name], @@ -299,16 +306,21 @@ class Quantizer(object): as override arguments for %s. Allowed kwargs: %s""" % (type(self), list(invalid_kwargs), type(module), list(valid_kwargs))) new_module = replace_fn(module, full_name, self.module_qbits_map, **valid_kwargs) - replace_msg(full_name, (module, new_module)) - # Add to history of prepared submodules - self.modules_processed[module] = full_name, new_module - setattr(container, name, new_module) - - # If a "leaf" module was replaced by a container, add the new layers to the QBits mapping - if not distiller.has_children(module) and distiller.has_children(new_module): - for sub_module_name, sub_module in new_module.named_modules(): - self._add_qbits_entry(full_name + '.' 
+ sub_module_name, type(sub_module), current_qbits) - self.module_qbits_map[full_name] = QBits(acts=current_qbits.acts, wts=None, bias=None) + if new_module != module: + replace_msg(full_name, (module, new_module)) + # Add to history of prepared submodules + self.modules_processed[module] = full_name, new_module + setattr(container, name, new_module) + + # If a "leaf" module was replaced by a container, add the new layers to the QBits mapping + if not distiller.has_children(module) and distiller.has_children(new_module): + for sub_module_name, sub_module in new_module.named_modules(): + self._add_qbits_entry(full_name + '.' + sub_module_name, type(sub_module), + current_qbits) + self.module_qbits_map[full_name] = QBits(acts=current_qbits.acts, wts=None, bias=None) + else: + replace_msg(full_name) + self.modules_processed[module] = full_name, None if distiller.has_children(module): # For container we call recursively diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py index 58b4a187d7afa1ab305599d14d94769d8492ac50..703e8721848e36d9d4803021849ffe42082eef15 100644 --- a/distiller/quantization/range_linear.py +++ b/distiller/quantization/range_linear.py @@ -15,9 +15,10 @@ # import torch.nn as nn +import torch.nn.functional as f import argparse from enum import Enum -from collections import OrderedDict +from collections import OrderedDict, namedtuple from functools import reduce, partial, update_wrapper import logging import os @@ -26,15 +27,28 @@ import warnings import distiller import distiller.utils -from .quantizer import Quantizer +from .quantizer import Quantizer, QBits from .q_utils import * +from .sim_bn_fold import SimulatedFoldedBatchNorm import distiller.modules import distiller.model_transforms as mt msglogger = logging.getLogger() +def _quant_param_to_str(val): + if isinstance(val, torch.Tensor): + if val.numel() > 1: + return 'PerCh' + else: + return '{:.6f}'.format(val.item()) + return '{:.6f}'.format(val) + + def _enum_to_str(enum_val): + # TODO: This can probably be removed + if isinstance(enum_val, str): # temporary fix + return enum_val return str(enum_val).split('.')[1] @@ -51,7 +65,6 @@ class ClipMode(Enum): AVG = 1 # Clip value calculated as mean of tensor + N standard deviations. 
N should be specified separately N_STD = 2 - # ACIQ Clipping Modes - GAUSS = 3 LAPLACE = 4 @@ -95,6 +108,7 @@ def _get_saturation_fn(quant_mode, clip_mode, num_stds, num_bits=None): return fns[clip_mode] +# TODO: Move to q_utils, add tests def _get_quant_params_from_tensor(tensor, num_bits, mode, clip=ClipMode.NONE, per_channel=False, num_stds=None, half_range=False, scale_approx_mult_bits=None): if per_channel and tensor.dim() not in [2, 4]: @@ -105,9 +119,6 @@ def _get_quant_params_from_tensor(tensor, num_bits, mode, clip=ClipMode.NONE, pe raise ValueError('N_STD clipping not supported with per-channel quantization') if num_stds is None: raise ValueError('Clip mode set top N_STD but \'num_stds\' parameter not provided') - if half_range and clip not in [ClipMode.GAUSS, ClipMode.LAPLACE]: - warnings.warn("Using clip_half_range without ACIQ clip modes (GAUSS or LAPACE) will have no" - " effect.") dim = 0 if clip == ClipMode.AVG or per_channel else None sat_fn = _get_saturation_fn(mode, clip, num_stds, num_bits) @@ -132,6 +143,7 @@ def _get_quant_params_from_tensor(tensor, num_bits, mode, clip=ClipMode.NONE, pe return scale, zp +# TODO: Move to q_utils, add tests def _get_quant_params_from_stats_dict(stats, num_bits, mode, clip=ClipMode.NONE, num_stds=None, half_range=False, scale_approx_mult_bits=None): if clip == ClipMode.N_STD: @@ -139,9 +151,6 @@ def _get_quant_params_from_stats_dict(stats, num_bits, mode, clip=ClipMode.NONE, raise ValueError('Clip mode set to N_STD but \'num_stds\' parameter not provided') if num_stds <= 0: raise ValueError('n_stds must be > 0, got {}'.format(num_stds)) - if half_range and clip not in [ClipMode.GAUSS, ClipMode.LAPLACE]: - warnings.warn("Using clip_half_range without ACIQ clip modes (GAUSS or LAPACE) will have no" - " effect.") prefix = 'avg_' if clip == ClipMode.AVG else '' sat_min = torch.tensor(float(stats[prefix + 'min'])) @@ -174,6 +183,39 @@ def _get_quant_params_from_stats_dict(stats, num_bits, mode, clip=ClipMode.NONE, # Post Training ############################################################################### +class TensorQuantMetadata(namedtuple('TensorQuantMetadata', ['scale', 'zero_point', 'min_q_val', 'max_q_val'])): + __slots__ = () + + def __str__(self): + return '(scale={} ; zero_point={})'.format(_quant_param_to_str(self.scale), + _quant_param_to_str(self.zero_point)) + + +class QuantSettings(object): + def __init__(self, num_bits, quant_mode, clip_mode, clip_n_stds, clip_half_range, per_channel): + self.num_bits = num_bits + self.quant_mode = quant_mode + self.clip_mode = clip_mode + self.clip_n_stds = clip_n_stds + self.clip_half_range = clip_half_range + self.per_channel = per_channel + + def __str__(self): + return '(num_bits={} ; quant_mode={} ; clip_mode={} ; clip_n_stds={} ; clip_half_range={}' \ + ' ; per_channel={})'.format(self.num_bits, _enum_to_str(self.quant_mode), + _enum_to_str(self.clip_mode), self.clip_n_stds, self.clip_half_range, + self.per_channel + ) + + +def linear_quantize_clamp_with_metadata(t, inplace=False): + return linear_quantize_clamp(t, *t.quant_metadata, inplace) + + +def linear_dequantize_with_metadata(t, inplace=False): + qmd = t.quant_metadata + return linear_dequantize(t, qmd.scale, qmd.zero_point, inplace) + def add_post_train_quant_args(argparser): str_to_quant_mode_map = {'sym': LinearQuantMode.SYMMETRIC, @@ -254,17 +296,42 @@ class RangeLinearQuantWrapper(nn.Module): def __init__(self, wrapped_module, num_bits_acts, num_bits_accum=32, mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE, 
activation_stats=None, clip_n_stds=None, clip_half_range=False, - scale_approx_mult_bits=None): + scale_approx_mult_bits=None, + input_overrides=None, requires_quantized_inputs=True, inputs_quant_auto_fallback=False): super(RangeLinearQuantWrapper, self).__init__() + input_overrides = input_overrides or OrderedDict() + self.wrapped_module = wrapped_module - self.num_bits_acts = num_bits_acts - self.num_bits_accum = num_bits_accum - self.mode = mode - self.clip_acts = clip_acts - self.clip_n_stds = clip_n_stds self.clip_half_range = clip_half_range self.scale_approx_mult_bits = scale_approx_mult_bits + self.requires_quantized_inputs = requires_quantized_inputs + self.inputs_quant_auto_fallback = inputs_quant_auto_fallback + + self.output_quant_settings = QuantSettings(num_bits_acts, mode, clip_acts, clip_n_stds, clip_half_range, False) + self.accum_quant_settings = QuantSettings(num_bits_accum, LinearQuantMode.SYMMETRIC, + ClipMode.NONE, None, False, False) + if self.requires_quantized_inputs: + self.inputs_quant_settings_overrides = OrderedDict() + for k, v in input_overrides.items(): + idx = int(k) + if v.pop('from_output', None): + quant_settings = deepcopy(self.output_quant_settings) + quant_settings.clip_half_range = False + else: + quant_settings = QuantSettings(v.pop('bits_activations', self.output_quant_settings.num_bits), + verify_quant_mode( + v.pop('mode', self.output_quant_settings.quant_mode)), + verify_clip_mode( + v.pop('clip_acts', self.output_quant_settings.clip_mode)), + v.pop('clip_n_stds', self.output_quant_settings.clip_n_stds), + False, False) + if v: + # Poor man's input checking on input overrides dict + raise ValueError('Input overrides dict contains unsupported keys:', list(v.keys())) + self.inputs_quant_settings_overrides[idx] = quant_settings + else: + self.inputs_quant_settings_overrides = None # Controls whether output is de-quantized at end of forward op. 
Meant as a debug / test flag only # (note that if False, the quantized output will be returned, but without any quantization parameters, @@ -278,58 +345,59 @@ class RangeLinearQuantWrapper(nn.Module): if activation_stats: self.preset_act_stats = True - self.num_inputs = 0 - for idx, stats in activation_stats['inputs'].items(): - self.num_inputs += 1 - scale, zp = _get_quant_params_from_stats_dict(stats, num_bits_acts, - mode, clip_acts, clip_n_stds, clip_half_range, - scale_approx_mult_bits) - prefix = 'in_{0}_'.format(idx) - self.register_buffer(prefix + 'scale', scale) - self.register_buffer(prefix + 'zero_point', zp) - - scale, zp = _get_quant_params_from_stats_dict(activation_stats['output'], num_bits_acts, mode, - clip_acts, clip_n_stds, clip_half_range, - scale_approx_mult_bits) + + if self.requires_quantized_inputs: + self.inputs_quant_metadata_fallback = OrderedDict() + for idx, stats in activation_stats['inputs'].items(): + settings = self.inputs_quant_settings_overrides.get(idx, self.output_quant_settings) + scale, zp = _get_quant_params_from_stats_dict( + stats, settings.num_bits, settings.quant_mode, settings.clip_mode, + settings.clip_n_stds, settings.clip_half_range, self.scale_approx_mult_bits + ) + min_q_val, max_q_val = get_quantized_range( + settings.num_bits, settings.quant_mode != LinearQuantMode.ASYMMETRIC_UNSIGNED) + qmd = TensorQuantMetadata(scale, zp, min_q_val, max_q_val) + self.inputs_quant_metadata_fallback[idx] = qmd + else: + self.inputs_quant_metadata_fallback = None + + scale, zp = _get_quant_params_from_stats_dict(activation_stats['output'], num_bits_acts, mode, clip_acts, + clip_n_stds, clip_half_range, scale_approx_mult_bits) self.register_buffer('output_scale', scale) self.register_buffer('output_zero_point', zp) else: self.preset_act_stats = False - def inputs_scales(self): - for scale in self._inputs_qparam('scale'): - yield scale - - def inputs_zero_points(self): - for zp in self._inputs_qparam('zero_point'): - yield zp - - def _inputs_qparam(self, type_str): - if type_str not in ['scale', 'zero_point']: - raise ValueError('Unknown quantization parameter type') - if not self.preset_act_stats: - raise RuntimeError('Input quantization parameter iterators only available when activation stats were given') - for idx in range(self.num_inputs): - name = 'in_{0}_{1}'.format(idx, type_str) - yield getattr(self, name) + self.register_buffer('num_forwards', torch.zeros(1, dtype=torch.long)) def forward(self, *inputs): if self.training: raise RuntimeError(self.__class__.__name__ + " can only be used in eval mode") + device = inputs[0].device for buffer_name, buffer in self._buffers.items(): setattr(self, buffer_name, buffer.to(device)) - in_scales, in_zero_points = self.get_inputs_quantization_params(*inputs) - - # Quantize inputs - inputs_q = [linear_quantize_clamp(input.data, scale, zp, - self.acts_min_q_val, self.acts_max_q_val, inplace=False) - for input, scale, zp in zip(inputs, in_scales, in_zero_points)] + if self.requires_quantized_inputs: + self._prepare_inputs_for_quantization(inputs) + + inputs_q = [] + for input in inputs: + qmd = input.quant_metadata + input.quant_metadata = TensorQuantMetadata(qmd.scale.to(device), qmd.zero_point.to(device), + qmd.min_q_val, qmd.max_q_val) + input_q = linear_quantize_clamp_with_metadata(input) + input_q.quant_metadata = input.quant_metadata + inputs_q.append(input_q) + else: + inputs_q = inputs # Forward through wrapped module accum = self.quantized_forward(*inputs_q) + if self.clip_half_range: + accum = 
f.relu(accum) + # Re-quantize accumulator to quantized output range out_scale, out_zero_point = self.get_output_quantization_params(accum) requant_scale, requant_zero_point = self.get_accum_to_output_re_quantization_params(out_scale, out_zero_point) @@ -342,18 +410,48 @@ class RangeLinearQuantWrapper(nn.Module): # De-quantize back to FP32 out_f = linear_dequantize(out_q, out_scale, out_zero_point, inplace=True) - return out_f + out_f.quant_metadata = TensorQuantMetadata(out_scale, out_zero_point, self.acts_min_q_val, self.acts_max_q_val) - def get_inputs_quantization_params(self, *inputs): - """ - Calculate input quantization parameters (scale and zero-point) + self.num_forwards += 1 - Should be overridden by all subclasses + return out_f - :param inputs: Current input tensors passed to forward method - :return: Tuple of 2 lists - list of scales per input and list of zero-point per input - """ - raise NotImplementedError + def _prepare_inputs_for_quantization(self, inputs): + for idx, input in enumerate(inputs): + if hasattr(input, 'quant_metadata'): + if idx in self.inputs_quant_settings_overrides: + raise RuntimeError('<{}> Input {}: CONFLICT - Tensor has embedded quantization metadata AND user ' + 'defined input quantization settings'.format(self.distiller_name, idx)) + else: + # Input doesn't have embedded quantization data propagated from a previous layer + # Our options are: + # If user set explicit settings for this input, use those + # OR + # If auto fallback is set, use the output quantization settings + if idx not in self.inputs_quant_settings_overrides and not self.inputs_quant_auto_fallback: + raise RuntimeError('<{}> Input {}: Expected tensor with embedded quantization metadata. Either:\n' + '1. Make sure the previous operation is quantized\n' + '2. Provide explicit input quantization settings\n' + '3. 
Set inputs_quant_auto_fallback'.format(self.distiller_name, idx)) + if self.preset_act_stats: + input.quant_metadata = self.inputs_quant_metadata_fallback[idx] + else: + if idx in self.inputs_quant_settings_overrides: + q_settings = self.inputs_quant_settings_overrides[idx] + else: + # If we're here then inputs_quant_auto_fallback is set + # if self.num_forwards == 0: + # msglogger.info('<{}> Input {}: No embedded quantization metadata, ' + # 'falling back to output settings'.format(self.distiller_name, idx)) + q_settings = deepcopy(self.output_quant_settings) + q_settings.clip_half_range = False + scale, zp = _get_quant_params_from_tensor(input, q_settings.num_bits, q_settings.quant_mode, + q_settings.clip_mode, q_settings.per_channel, + q_settings.clip_n_stds, q_settings.clip_half_range, + self.scale_approx_mult_bits) + signed = q_settings.quant_mode != LinearQuantMode.ASYMMETRIC_UNSIGNED + min_q_val, max_q_val = get_quantized_range(q_settings.num_bits, signed) + input.quant_metadata = TensorQuantMetadata(scale, zp, min_q_val, max_q_val) def quantized_forward(self, *inputs_q): """ @@ -393,19 +491,26 @@ class RangeLinearQuantWrapper(nn.Module): raise NotImplementedError def extra_repr(self): - tmpstr = 'mode={0}, '.format(str(self.mode).split('.')[1]) - tmpstr += 'num_bits_acts={0}, num_bits_accum={1}, '.format(self.num_bits_acts, self.num_bits_accum) - tmpstr += 'clip_acts={0}, '.format(_enum_to_str(self.clip_acts)) - if self.clip_acts == ClipMode.N_STD: - tmpstr += 'num_stds={} '.format(self.clip_n_stds) - tmpstr += 'scale_approx_mult_bits={}'.format(self.scale_approx_mult_bits) + tmpstr = 'output_quant_settings={0}'.format(self.output_quant_settings) + tmpstr += '\naccum_quant_settings={0}'.format(self.accum_quant_settings) + tmpstr += '\nrequires_quantized_inputs={0}'.format(self.requires_quantized_inputs) + if self.requires_quantized_inputs: + overrides = self.inputs_quant_settings_overrides + tmpstr += '\n inputs_quant_auto_fallback={}'.format(self.inputs_quant_auto_fallback) + tmpstr += ', forced_quant_settings_for_inputs={}'.format( + 'None' if not overrides else list(overrides.keys())) + for idx, qset in overrides.items(): + tmpstr += '\n input_{}_settings={}'.format(idx, qset) + tmpstr += '\nscale_approx_mult_bits={}'.format(self.scale_approx_mult_bits) tmpstr += '\npreset_activation_stats={0}'.format(self.preset_act_stats) if self.preset_act_stats: - for idx, (in_scale, in_zp) in enumerate(zip(self.inputs_scales(), self.inputs_zero_points())): - tmpstr += '\nin_{i}_scale={sc}, in_{i}_zero_point={zp}'.format(i=idx, sc=in_scale.item(), - zp=in_zp.item()) - tmpstr += '\nout_scale={sc}, out_zero_point={zp}'.format(sc=self.output_scale.item(), - zp=self.output_zero_point.item()) + tmpstr += '\n output_scale={0}, output_zero_point={1}'.format(_quant_param_to_str( + self.output_scale), _quant_param_to_str(self.output_zero_point)) + if self.requires_quantized_inputs: + for idx in self.inputs_quant_settings_overrides: + qmd = self.inputs_quant_metadata_fallback[idx] + tmpstr += '\n input_#{0}_scale={1}, input_#{0}_zero_point={2}'.format( + idx, _quant_param_to_str(qmd.scale), _quant_param_to_str(qmd.zero_point)) return tmpstr @@ -444,29 +549,35 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper): per_channel_wts (bool): Enable quantization of weights using separate quantization parameters per output channel activation_stats (dict): See RangeLinearQuantWrapper - clip_n_stds (int): See RangeLinearQuantWrapper + clip_n_stds (float): See RangeLinearQuantWrapper 
clip_half_range (bool) : See RangeLinearQuantWrapper scale_approx_mult_bits (int): See RangeLinearQuantWrapper """ def __init__(self, wrapped_module, num_bits_acts, num_bits_params, num_bits_accum=32, mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE, per_channel_wts=False, activation_stats=None, - clip_n_stds=None, clip_half_range=False, scale_approx_mult_bits=None): + clip_n_stds=None, clip_half_range=False, scale_approx_mult_bits=None, + input_overrides=None, inputs_quant_auto_fallback=False): super(RangeLinearQuantParamLayerWrapper, self).__init__(wrapped_module, num_bits_acts, num_bits_accum, mode, - clip_acts, activation_stats, clip_n_stds, - clip_half_range, scale_approx_mult_bits) + clip_acts, activation_stats, clip_n_stds, clip_half_range, + scale_approx_mult_bits, + input_overrides=input_overrides, + requires_quantized_inputs=True, + inputs_quant_auto_fallback=inputs_quant_auto_fallback) if not isinstance(wrapped_module, (nn.Conv2d, nn.Conv3d, nn.Linear)): raise ValueError(self.__class__.__name__ + ' can wrap only Conv2D, Conv3D and Linear modules') - self.num_bits_params = num_bits_params - self.per_channel_wts = per_channel_wts + self.wts_quant_settings = QuantSettings(num_bits_params, mode, ClipMode.NONE, None, False, per_channel_wts) self.params_min_q_val, self.params_max_q_val = get_quantized_range( - num_bits_params, signed=mode != LinearQuantMode.ASYMMETRIC_UNSIGNED) + self.wts_quant_settings.num_bits, + self.wts_quant_settings.quant_mode != LinearQuantMode.ASYMMETRIC_UNSIGNED) # Quantize weights - overwrite FP32 weights - w_scale, w_zero_point = _get_quant_params_from_tensor(wrapped_module.weight, num_bits_params, self.mode, - per_channel=per_channel_wts) + w_scale, w_zero_point = _get_quant_params_from_tensor(wrapped_module.weight, + self.wts_quant_settings.num_bits, + self.wts_quant_settings.quant_mode, + per_channel=self.wts_quant_settings.per_channel) self.register_buffer('w_scale', w_scale) self.register_buffer('w_zero_point', w_zero_point) @@ -477,27 +588,25 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper): device = self.w_scale.device if self.preset_act_stats: - self.in_0_scale = self.in_0_scale.to(device) - self.register_buffer('accum_scale', self.in_0_scale * self.w_scale) - if self.per_channel_wts: - self.accum_scale = self.accum_scale.squeeze(dim=-1) + t = torch.zeros_like(self.w_scale) + if self.wts_quant_settings.per_channel: + t = t.squeeze(dim=-1) + self.register_buffer('accum_scale', t) else: - self.accum_scale = 1 + self.accum_scale = torch.ones(1).to(device) # Quantize bias self.has_bias = hasattr(wrapped_module, 'bias') and wrapped_module.bias is not None - if self.has_bias: - if self.preset_act_stats: - linear_quantize_clamp(wrapped_module.bias.data, self.accum_scale.squeeze(), 0, - self.accum_min_q_val, self.accum_max_q_val, inplace=True) - else: - b_scale, b_zero_point = _get_quant_params_from_tensor(wrapped_module.bias, num_bits_params, self.mode) - self.register_buffer('b_scale', b_scale) - self.register_buffer('b_zero_point', b_zero_point) - base_b_q = linear_quantize_clamp(wrapped_module.bias.data, self.b_scale, self.b_zero_point, - self.params_min_q_val, self.params_max_q_val) - # Dynamic ranges - save in auxiliary buffer, requantize each time based on dynamic input scale factor - self.register_buffer('base_b_q', base_b_q) + if self.has_bias and not self.preset_act_stats: + b_scale, b_zero_point = _get_quant_params_from_tensor(wrapped_module.bias, + self.wts_quant_settings.num_bits, + self.wts_quant_settings.quant_mode) 
+ self.register_buffer('b_scale', b_scale) + self.register_buffer('b_zero_point', b_zero_point) + base_b_q = linear_quantize_clamp(wrapped_module.bias.data, self.b_scale, self.b_zero_point, + self.params_min_q_val, self.params_max_q_val) + # Dynamic ranges - save in auxiliary buffer, requantize each time based on dynamic input scale factor + self.register_buffer('base_b_q', base_b_q) # A flag indicating that the simulated quantized weights are pre-shifted. for faster performance. # In the first forward pass - `w_zero_point` is added into the weights, to allow faster inference, @@ -513,21 +622,26 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper): self.is_simulated_quant_weight_shifted.sub_(1) # i.e. is_simulated_quant_weight_shifted = False return super(RangeLinearQuantParamLayerWrapper, self).state_dict(destination, prefix, keep_vars) - def get_inputs_quantization_params(self, input): - if not self.preset_act_stats: - self.in_0_scale, self.in_0_zero_point = _get_quant_params_from_tensor( - input, self.num_bits_acts, self.mode, clip=self.clip_acts, - num_stds=self.clip_n_stds, scale_approx_mult_bits=self.scale_approx_mult_bits) - return [self.in_0_scale], [self.in_0_zero_point] - def quantized_forward(self, input_q): # See class documentation for quantized calculation details. - if not self.preset_act_stats: - self.accum_scale = self.in_0_scale * self.w_scale - if self.per_channel_wts: - self.accum_scale = self.accum_scale.squeeze(dim=-1) + def get_accum_scale(input_q): + accum_scale = input_q.quant_metadata.scale * self.w_scale + if self.wts_quant_settings.per_channel: + accum_scale = accum_scale.squeeze(dim=-1) + if self.scale_approx_mult_bits: + accum_scale = approx_scale_as_mult_and_shift(accum_scale, self.scale_approx_mult_bits) + return accum_scale + if self.preset_act_stats: + if self.num_forwards == 0: + self.accum_scale += get_accum_scale(input_q) + if self.has_bias: + # Requantize bias to accumulator scale "permanently" + linear_quantize_clamp(self.wrapped_module.bias.data, self.accum_scale.squeeze(), 0, + self.accum_min_q_val, self.accum_max_q_val, inplace=True) + else: + self.accum_scale = get_accum_scale(input_q) if self.has_bias: # Re-quantize bias to match x * w scale: b_q' = (in_scale * w_scale / b_scale) * (b_q + b_zero_point) bias_requant_scale = self.accum_scale.squeeze() / self.b_scale @@ -546,13 +660,13 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper): # to the input and weights and pass those to the wrapped model. Functionally, since at this point we're # dealing solely with integer values, the results are the same either way. - if self.mode != LinearQuantMode.SYMMETRIC and not self.is_simulated_quant_weight_shifted: + if self.output_quant_settings.quant_mode != LinearQuantMode.SYMMETRIC and not self.is_simulated_quant_weight_shifted: # We "store" the w_zero_point inside our wrapped module's weights to # improve performance on inference. self.wrapped_module.weight.data += self.w_zero_point - self.is_simulated_quant_weight_shifted.add_(1) # i.e. is_simulated_quant_weight_shifted = True + self.is_simulated_quant_weight_shifted.add_(1) # i.e. 
is_simulated_quant_weight_shifted = True - input_q += self.in_0_zero_point + input_q += input_q.quant_metadata.zero_point accum = self.wrapped_module.forward(input_q) clamp(accum.data, self.accum_min_q_val, self.accum_max_q_val, inplace=True) @@ -563,8 +677,10 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper): return self.output_scale, self.output_zero_point y_f = accumulator / self.accum_scale - return _get_quant_params_from_tensor(y_f, self.num_bits_acts, self.mode, clip=self.clip_acts, - num_stds=self.clip_n_stds, + q_set = self.output_quant_settings + return _get_quant_params_from_tensor(y_f, q_set.num_bits, q_set.quant_mode, + clip=q_set.clip_mode, num_stds=q_set.clip_n_stds, + half_range=q_set.clip_half_range, scale_approx_mult_bits=self.scale_approx_mult_bits) def get_accum_to_output_re_quantization_params(self, output_scale, output_zero_point): @@ -574,28 +690,13 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper): return requant_scale, output_zero_point def extra_repr(self): - tmpstr = 'mode={0}, '.format(str(self.mode).split('.')[1]) - tmpstr += 'num_bits_acts={0}, num_bits_params={1}, num_bits_accum={2}, '.format(self.num_bits_acts, - self.num_bits_params, - self.num_bits_accum) - tmpstr += 'clip_acts={0}, '.format(_enum_to_str(self.clip_acts)) - if self.clip_acts == ClipMode.N_STD: - tmpstr += 'num_stds={} '.format(self.clip_n_stds) - tmpstr += 'per_channel_wts={}, scale_approx_mult_bits={}'.format(self.per_channel_wts, - self.scale_approx_mult_bits) - tmpstr += '\npreset_activation_stats={0}'.format(self.preset_act_stats) - if self.per_channel_wts: - tmpstr += '\nw_scale=PerCh, w_zero_point=PerCh' - else: - tmpstr += '\nw_scale={0:.4f}, w_zero_point={1:.4f}'.format(self.w_scale.item(), self.w_zero_point.item()) - if self.preset_act_stats: - tmpstr += '\nin_scale={0:.4f}, in_zero_point={1:.4f}'.format(self.in_0_scale.item(), - self.in_0_zero_point.item()) - tmpstr += '\nout_scale={0:.4f}, out_zero_point={1:.4f}'.format(self.output_scale.item(), - self.output_zero_point.item()) - elif self.has_bias: - tmpstr += '\nbase_b_scale={0:.4f}, base_b_zero_point={1:.4f}'.format(self.b_scale.item(), - self.b_zero_point.item()) + tmpstr = 'weights_quant_settings={0}\n'.format(self.wts_quant_settings) + tmpstr += super(RangeLinearQuantParamLayerWrapper, self).extra_repr() + tmpstr += '\nweights_scale={0}, weights_zero_point={1}'.format(_quant_param_to_str(self.w_scale), + _quant_param_to_str(self.w_zero_point)) + if not self.preset_act_stats and self.has_bias: + tmpstr += '\nbase_bias_scale={0}, base_bias_zero_point={1}'.format(_quant_param_to_str(self.b_scale), + _quant_param_to_str(self.b_zero_point)) return tmpstr @@ -625,32 +726,24 @@ class RangeLinearQuantMatmulWrapper(RangeLinearQuantWrapper): scale_approx_mult_bits (int): See RangeLinearQuantWrapper """ def __init__(self, wrapped_module, num_bits_acts, num_bits_accum=32, - mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE, activation_stats=None, - clip_n_stds=None, clip_half_range=False, scale_approx_mult_bits=None): + mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE, activation_stats=None, + clip_n_stds=None, clip_half_range=False, scale_approx_mult_bits=None, + input_overrides=None, inputs_quant_auto_fallback=False): super(RangeLinearQuantMatmulWrapper, self).__init__(wrapped_module, num_bits_acts, num_bits_accum, mode, clip_acts, activation_stats, clip_n_stds, clip_half_range, - scale_approx_mult_bits) + scale_approx_mult_bits, + input_overrides=input_overrides, + 
requires_quantized_inputs=True, + inputs_quant_auto_fallback=inputs_quant_auto_fallback) if not isinstance(wrapped_module, (distiller.modules.Matmul, distiller.modules.BatchMatmul)): raise ValueError(self.__class__.__name__ + ' can wrap only Matmul modules') - if self.preset_act_stats: - self.register_buffer('accum_scale', self.in_0_scale * self.in_1_scale) - else: - self.accum_scale = 1 - - def get_inputs_quantization_params(self, input0, input1): - if not self.preset_act_stats: - self.in_0_scale, self.in_0_zero_point = _get_quant_params_from_tensor( - input0, self.num_bits_acts, self.mode, clip=self.clip_acts, - num_stds=self.clip_n_stds, scale_approx_mult_bits=self.scale_approx_mult_bits) - self.in_1_scale, self.in_1_zero_point = _get_quant_params_from_tensor( - input0, self.num_bits_acts, self.mode, clip=self.clip_acts, - num_stds=self.clip_n_stds, scale_approx_mult_bits=self.scale_approx_mult_bits) - return [self.in_0_scale, self.in_1_scale], [self.in_0_zero_point, self.in_1_zero_point] + self.accum_scale = 1 def quantized_forward(self, input0_q, input1_q): - accum = self.wrapped_module.forward(input0_q + self.in_0_zero_point, - input1_q + self.in_1_zero_point) + self.accum_scale = input0_q.quant_metadata.scale * input1_q.quant_metadata.scale + accum = self.wrapped_module.forward(input0_q + input0_q.quant_metadata.zero_point, + input1_q + input1_q.quant_metadata.zero_point) clamp(accum.data, self.accum_min_q_val, self.accum_max_q_val, inplace=True) return accum @@ -659,8 +752,10 @@ class RangeLinearQuantMatmulWrapper(RangeLinearQuantWrapper): return self.output_scale, self.output_zero_point y_f = accumulator / self.accum_scale - return _get_quant_params_from_tensor(y_f, self.num_bits_acts, self.mode, clip=self.clip_acts, - num_stds=self.clip_n_stds, + q_set = self.output_quant_settings + return _get_quant_params_from_tensor(y_f, q_set.num_bits, q_set.quant_mode, + clip=q_set.clip_mode, num_stds=q_set.clip_n_stds, + half_range=q_set.clip_half_range, scale_approx_mult_bits=self.scale_approx_mult_bits) def get_accum_to_output_re_quantization_params(self, output_scale, output_zero_point): @@ -669,25 +764,6 @@ class RangeLinearQuantMatmulWrapper(RangeLinearQuantWrapper): requant_scale = approx_scale_as_mult_and_shift(requant_scale, self.scale_approx_mult_bits) return requant_scale, output_zero_point - def extra_repr(self): - tmpstr = 'mode={0}, '.format(str(self.mode).split('.')[1]) - tmpstr += 'num_bits_acts={0}, num_bits_accum={1}, '.format(self.num_bits_acts, self.num_bits_accum) - tmpstr += 'clip_acts={0}, '.format(_enum_to_str(self.clip_acts)) - if self.clip_acts == ClipMode.N_STD: - tmpstr += 'num_stds={} '.format(self.clip_n_stds) - tmpstr += 'scale_approx_mult_bits={}'.format(self.scale_approx_mult_bits) - tmpstr += '\npreset_activation_stats={0}'.format(self.preset_act_stats) - if self.preset_act_stats: - tmpstr += '\nin_0_scale={0:.4f}, in_0_zero_point={1:.4f}'.format(self.in_0_scale.item(), - self.in_0_zero_point.item()) - - tmpstr += '\nin_1_scale={0:.4f}, in_1_zero_point={1:.4f}'.format(self.in_1_scale.item(), - self.in_1_zero_point.item()) - - tmpstr += '\nout_scale={0:.4f}, out_zero_point={1:.4f}'.format(self.output_scale.item(), - self.output_zero_point.item()) - return tmpstr - class NoStatsError(NotImplementedError): pass @@ -695,7 +771,8 @@ class NoStatsError(NotImplementedError): class RangeLinearQuantConcatWrapper(RangeLinearQuantWrapper): def __init__(self, wrapped_module, num_bits_acts, mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE, - activation_stats=None, 
clip_n_stds=None, clip_half_range=False, scale_approx_mult_bits=None): + activation_stats=None, clip_n_stds=None, clip_half_range=False, scale_approx_mult_bits=None, + input_overrides=None, inputs_quant_auto_fallback=False): if not isinstance(wrapped_module, distiller.modules.Concat): raise ValueError(self.__class__.__name__ + ' can only wrap distiller.modules.Concat modules') @@ -706,33 +783,18 @@ class RangeLinearQuantConcatWrapper(RangeLinearQuantWrapper): super(RangeLinearQuantConcatWrapper, self).__init__(wrapped_module, num_bits_acts, mode=mode, clip_acts=clip_acts, activation_stats=activation_stats, clip_n_stds=clip_n_stds, clip_half_range=clip_half_range, - scale_approx_mult_bits=scale_approx_mult_bits) - - if self.preset_act_stats: - # For concatenation to make sense, we need to match all the inputs' scales, so we - # set a re-scale factor based on the preset output scale - for idx, in_scale in enumerate(self.inputs_scales()): - requant_scale = self.output_scale / in_scale - if self.scale_approx_mult_bits is not None: - requant_scale = approx_scale_as_mult_and_shift(requant_scale, self.scale_approx_mult_bits) - self.register_buffer('in_{0}_requant_scale'.format(idx), requant_scale) - - def inputs_requant_scales(self): - if not self.preset_act_stats: - raise RuntimeError('Input quantization parameter iterators only available when activation stats were given') - for idx in range(self.num_inputs): - name = 'in_{0}_requant_scale'.format(idx) - yield getattr(self, name) - - def get_inputs_quantization_params(self, *inputs): - return self.inputs_scales(), self.inputs_zero_points() + scale_approx_mult_bits=scale_approx_mult_bits, + input_overrides=input_overrides, + requires_quantized_inputs=True, + inputs_quant_auto_fallback=inputs_quant_auto_fallback) def quantized_forward(self, *inputs_q): - # Re-quantize all inputs based to the same range (the output range) - inputs_re_q = [linear_quantize_clamp(input_q + zp, requant_scale, self.output_zero_point, - self.acts_min_q_val, self.acts_max_q_val, inplace=False) - for input_q, requant_scale, zp in zip(inputs_q, self.inputs_requant_scales(), - self.inputs_zero_points())] + # For concatenation to make sense input scales need to match, so we re-quantize all inputs + # based on the output scale + inputs_re_q = [linear_quantize_clamp(input_q + input_q.quant_metadata.zero_point, + self.output_scale / input_q.quant_metadata.scale, 0., + self.accum_min_q_val, self.accum_max_q_val, inplace=False) + for input_q in inputs_q] return self.wrapped_module(*inputs_re_q) def get_output_quantization_params(self, accumulator): @@ -740,12 +802,13 @@ class RangeLinearQuantConcatWrapper(RangeLinearQuantWrapper): def get_accum_to_output_re_quantization_params(self, output_scale, output_zero_point): # Nothing to do here, since we already re-quantized in quantized_forward prior to the actual concatenation - return 1., 0. 
+ return 1., self.output_zero_point class RangeLinearQuantEltwiseAddWrapper(RangeLinearQuantWrapper): def __init__(self, wrapped_module, num_bits_acts, mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE, - activation_stats=None, clip_n_stds=None, clip_half_range=False, scale_approx_mult_bits=None): + activation_stats=None, clip_n_stds=None, clip_half_range=False, scale_approx_mult_bits=None, + input_overrides=None, inputs_quant_auto_fallback=False): if not isinstance(wrapped_module, distiller.modules.EltwiseAdd): raise ValueError(self.__class__.__name__ + ' can only wrap distiller.modules.EltwiseAdd modules') @@ -755,36 +818,18 @@ class RangeLinearQuantEltwiseAddWrapper(RangeLinearQuantWrapper): super(RangeLinearQuantEltwiseAddWrapper, self).__init__(wrapped_module, num_bits_acts, mode=mode, clip_acts=clip_acts, activation_stats=activation_stats, - clip_n_stds=clip_n_stds, - clip_half_range=clip_half_range, - scale_approx_mult_bits=scale_approx_mult_bits) - - if self.preset_act_stats: - # For addition to make sense, all input scales must match. So we set a re-scale factor according - # to the preset output scale - requant_scales = [self.output_scale / in_scale for in_scale in self.inputs_scales()] - if scale_approx_mult_bits is not None: - requant_scales = [approx_scale_as_mult_and_shift(requant_scale, scale_approx_mult_bits) - for requant_scale in requant_scales] - for idx, requant_scale in enumerate(requant_scales): - self.register_buffer('in_{0}_requant_scale'.format(idx), requant_scale) - - def inputs_requant_scales(self): - if not self.preset_act_stats: - raise RuntimeError('Input quantization parameter iterators only available when activation stats were given') - for idx in range(self.num_inputs): - name = 'in_{0}_requant_scale'.format(idx) - yield getattr(self, name) - - def get_inputs_quantization_params(self, *inputs): - return self.inputs_scales(), self.inputs_zero_points() + clip_n_stds=clip_n_stds, clip_half_range=clip_half_range, + scale_approx_mult_bits=scale_approx_mult_bits, + input_overrides=input_overrides, + requires_quantized_inputs=True, + inputs_quant_auto_fallback=inputs_quant_auto_fallback) def quantized_forward(self, *inputs_q): - # Re-scale inputs to the accumulator range - inputs_re_q = [linear_quantize_clamp(input_q + zp, requant_scale, 0, + # Re-scale inputs to the accumulator scale + inputs_re_q = [linear_quantize_clamp(input_q + input_q.quant_metadata.zero_point, + self.output_scale / input_q.quant_metadata.scale, 0, self.accum_min_q_val, self.accum_max_q_val, inplace=False) - for input_q, requant_scale, zp in zip(inputs_q, self.inputs_requant_scales(), - self.inputs_zero_points())] + for input_q in inputs_q] accum = self.wrapped_module(*inputs_re_q) clamp(accum.data, self.accum_min_q_val, self.accum_max_q_val, inplace=True) @@ -799,7 +844,8 @@ class RangeLinearQuantEltwiseAddWrapper(RangeLinearQuantWrapper): class RangeLinearQuantEltwiseMultWrapper(RangeLinearQuantWrapper): def __init__(self, wrapped_module, num_bits_acts, mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE, - activation_stats=None, clip_n_stds=None, clip_half_range=False, scale_approx_mult_bits=None): + activation_stats=None, clip_n_stds=None, clip_half_range=False, scale_approx_mult_bits=None, + input_overrides=None, inputs_quant_auto_fallback=False): if not isinstance(wrapped_module, distiller.modules.EltwiseMult): raise ValueError(self.__class__.__name__ + ' can only wrap distiller.modules.EltwiseMult modules') @@ -809,20 +855,19 @@ class 
RangeLinearQuantEltwiseMultWrapper(RangeLinearQuantWrapper): super(RangeLinearQuantEltwiseMultWrapper, self).__init__(wrapped_module, num_bits_acts, mode=mode, clip_acts=clip_acts, activation_stats=activation_stats, - clip_n_stds=clip_n_stds, - clip_half_range=clip_half_range, - scale_approx_mult_bits=scale_approx_mult_bits) - - if self.preset_act_stats: - self.register_buffer('accum_scale', reduce(lambda x, y: x * y, self.inputs_scales())) - - def get_inputs_quantization_params(self, *inputs): - return self.inputs_scales(), self.inputs_zero_points() + clip_n_stds=clip_n_stds, clip_half_range=clip_half_range, + scale_approx_mult_bits=scale_approx_mult_bits, + input_overrides=input_overrides, + requires_quantized_inputs=True, + inputs_quant_auto_fallback=inputs_quant_auto_fallback) + self.accum_scale = 1 def quantized_forward(self, *inputs_q): - if self.mode != LinearQuantMode.SYMMETRIC: - for input_q, zp in zip(inputs_q, self.inputs_zero_points()): - input_q += zp + input_scales = [input_q.quant_metadata.scale for input_q in inputs_q] + self.accum_scale = reduce(lambda x, y: x * y, input_scales) + + for input_q in inputs_q: + input_q += input_q.quant_metadata.zero_point accum = self.wrapped_module(*inputs_q) clamp(accum.data, self.accum_min_q_val, self.accum_max_q_val, inplace=True) @@ -839,26 +884,28 @@ class RangeLinearQuantEltwiseMultWrapper(RangeLinearQuantWrapper): return requant_scale, output_zero_point -class FP16Wrapper(nn.Module): +class FPWrapper(nn.Module): """ A wrapper that replaces a module with a half precision version. - Args: module (nn.Module): The module to be replaced. - convert_input (:obj:`bool`, optional): Specifies whether an input conversion - to fp16 is required for forward. Default: True. - return_fp32 (:obj:`bool`, optional): Specifies whether the output needs + precision (Union[str, int]): the floating point precision to use. Either 16/32/64. + convert_input (bool): Specifies whether an input conversion + to module precision is required for forward. Default: True. + return_fp32 (bool): Specifies whether the output needs to be converted back to fp32. Default: True. 
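As a usage sketch only (not part of the patch): it assumes FPWrapper/FP16Wrapper are importable from distiller.quantization.range_linear, the module this hunk modifies, and the layer/tensor shapes are made up.

```python
import torch
import torch.nn as nn
from distiller.quantization.range_linear import FPWrapper, FP16Wrapper

fc = nn.Linear(128, 10)

# Run the layer in double precision; inputs are cast to the module's dtype and,
# since return_fp32=True by default, the output is cast back to float32.
wrapped = FPWrapper(fc, precision=64)
y = wrapped(torch.randn(4, 128))

# FP16Wrapper is kept as a thin shortcut for FPWrapper(module, precision=16)
half_fc = FP16Wrapper(nn.Linear(128, 10))
```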
""" - def __init__(self, module: nn.Module, convert_input=True, return_fp32=True): - super(FP16Wrapper, self).__init__() - self.wrapped_module = module.half() + def __init__(self, module: nn.Module, precision, convert_input=True, return_fp32=True): + super(FPWrapper, self).__init__() + precision = str(precision) + self.dtype = {'16': torch.float16, '32': torch.float32, '64': torch.float64}[precision] + self.wrapped_module = module.to(self.dtype) self.return_fp32 = return_fp32 - self.convert_input_fp16 = convert_input + self.convert_input = convert_input def forward(self, *input): - if self.convert_input_fp16: - input = distiller.convert_tensors_recursively_to(input, dtype=torch.float16) + if self.convert_input: + input = distiller.convert_tensors_recursively_to(input, dtype=self.dtype) result = self.wrapped_module(*input) if self.return_fp32: @@ -867,6 +914,11 @@ class FP16Wrapper(nn.Module): return result +class FP16Wrapper(FPWrapper): + def __init__(self, module, convert_input=True, return_fp32=True): + super(FP16Wrapper, self).__init__(module, 16, convert_input, return_fp32) + + class RangeLinearEmbeddingWrapper(nn.Module): def __init__(self, wrapped_module, num_bits, mode=LinearQuantMode.SYMMETRIC, stats=None): if not isinstance(wrapped_module, nn.Embedding): @@ -888,20 +940,69 @@ class RangeLinearEmbeddingWrapper(nn.Module): self.register_buffer('w_zero_point', w_zero_point.to(device)) linear_quantize_clamp(wrapped_module.weight.data, self.w_scale, self.w_zero_point, self.min_q_val, self.max_q_val, inplace=True) - + self.quant_metadata = TensorQuantMetadata(self.w_scale, self.w_zero_point, self.min_q_val, self.max_q_val) self.wrapped_module = wrapped_module def forward(self, input): out_q = self.wrapped_module(input) out_f = linear_dequantize(out_q, self.w_scale, self.w_zero_point, inplace=True) + out_f.quant_metadata = self.quant_metadata return out_f +class RangeLinearFakeQuantWrapper(RangeLinearQuantWrapper): + def __init__(self, wrapped_module, num_bits_acts, mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE, + activation_stats=None, clip_n_stds=None, clip_half_range=False, scale_approx_mult_bits=None, + fpq_module=None): + super(RangeLinearFakeQuantWrapper, self).__init__(wrapped_module, num_bits_acts, mode=mode, + clip_acts=clip_acts, activation_stats=activation_stats, + clip_n_stds=clip_n_stds, clip_half_range=clip_half_range, + scale_approx_mult_bits=scale_approx_mult_bits, + requires_quantized_inputs=False) + self.fpq_module = str(fpq_module) if fpq_module else None + self.dtype = torch.float + if self.fpq_module: + self.dtype = {'16': torch.half, '32': torch.float, '64': torch.double}[self.fpq_module] + self.wrapped_module.to(self.dtype) + + def quantized_forward(self, *inputs_q): + inputs_q = distiller.convert_tensors_recursively_to(inputs_q, dtype=self.dtype) + outputs = self.wrapped_module(*inputs_q) + return distiller.convert_tensors_recursively_to(outputs, dtype=self.dtype) + + def get_output_quantization_params(self, accumulator): + if self.preset_act_stats: + return self.output_scale, self.output_zero_point + else: + q_set = self.output_quant_settings + return _get_quant_params_from_tensor(accumulator, q_set.num_bits, q_set.quant_mode, q_set.clip_mode, + q_set.per_channel, q_set.clip_n_stds,q_set.clip_half_range, + self.scale_approx_mult_bits) + + def get_accum_to_output_re_quantization_params(self, output_scale, output_zero_point): + return output_scale, output_zero_point + + class PostTrainLinearQuantizer(Quantizer): """ Applies range-based linear quantization 
to a model. This quantizer is expected to be executed at evaluation only, on a pre-trained model - Currently, the following Modules are supported: torch.nn.Conv2d, torch.nn.Linear + + The following modules / operations have dedicated implementations which consider quantization: + * torch.nn.Conv2d/Conv3d + * torch.nn.Linear + * torch.nn.Embedding + * distiller.modules.Concat + * distiller.modules.EltwiseAdd + * distiller.modules.EltwiseMult + * distiller.modules.Matmul + * distiller.modules.BatchMatmul + An existing module will likely need to be modified to use the 'distiller.modules.*' modules. This needs to + be done BEFORE creating the quantizer. See the docs for more details: + https://nervanasystems.github.io/distiller/prepare_model_quant.html + + Any leaf module not in the list above will be "fake-quantized". That is - the floating-point module will be + executed (FP64/32/16 can be specified with the fpq_module argument), and its output will be quantized. Args: model (torch.nn.Module): Model to be quantized @@ -916,23 +1017,34 @@ class PostTrainLinearQuantizer(Quantizer): The dict should be in the format exported by distiller.data_loggers.QuantCalibrationStatsCollector. If None then parameters are calculated dynamically. fp16 (bool): Set to True to convert modules to half precision. + WARNING - this argument is deprecated, use the `fpq_module` argument instead clip_n_stds (float): When clip_acts == ClipMode.N_STD, this is the number of standard deviations to use + clip_half_range (bool): When clipping activations, clip only the half range [0, max]. Relevant for layers followed by a ReLU, where negative values don't need to be quantized scale_approx_mult_bits (int): If not None, scale factors will be approximated using an integer multiplication followed by a bit-wise shift. This eliminates floating-point scale factors, replacing them with integer calculations. If None, scale factors will be kept in their original FP32 values. + inputs_quant_auto_fallback (bool): Enabled by default. + See <distiller_root>/examples/post_train_quant/resnet18_imagenet_post_train_input_overrides.yaml + for details on what this does and how to override it. + fpq_module (Union[int, str]): Run the modules in floating point mode and only quantize their outputs. + Takes one of the values (16, 32, 64). This will use RangeLinearFakeQuantWrapper. Note: - If fp16 is set to True, all the layers (except those overridden in `overrides`) will be converted - to half precision, regardless of bits_activations/parameters/accum. + If fpq_module is set, all the layers (except those overridden in `overrides`) will be converted + to the set floating point precision, regardless of bits_activations/parameters/accum.
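To make the preparation flow described above concrete, here is a minimal hypothetical sketch (not part of the patch): the module, shapes and the stats file path are made up, and the stats file would normally be produced beforehand with distiller.data_loggers.QuantCalibrationStatsCollector.

```python
import torch
import torch.nn as nn
import distiller.modules
from distiller.quantization import PostTrainLinearQuantizer, LinearQuantMode

class TinyBlock(nn.Module):
    """Toy residual-style block; the '+' is replaced by distiller.modules.EltwiseAdd."""
    def __init__(self, channels=8):
        super(TinyBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.add = distiller.modules.EltwiseAdd()   # gets a dedicated quantized wrapper
        self.relu2 = nn.ReLU()

    def forward(self, x):
        out = self.conv2(self.relu1(self.conv1(x)))
        return self.relu2(self.add(out, x))

model = TinyBlock().eval()
quantizer = PostTrainLinearQuantizer(
    model,
    bits_activations=8, bits_parameters=8, bits_accum=32,
    mode=LinearQuantMode.ASYMMETRIC_UNSIGNED,
    model_activation_stats='acts_quant_stats.yaml',  # hypothetical path to pre-collected calibration stats
    inputs_quant_auto_fallback=True)
quantizer.prepare_model(torch.randn(1, 8, 16, 16))
```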
""" def __init__(self, model, bits_activations=8, bits_parameters=8, bits_accum=32, overrides=None, mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE, - per_channel_wts=False, model_activation_stats=None, fp16=False, clip_n_stds=None, - scale_approx_mult_bits=None): + per_channel_wts=False, model_activation_stats=None, fp16=False, + clip_n_stds=None, clip_half_range=False, + scale_approx_mult_bits=None, inputs_quant_auto_fallback=True, + fpq_module=None): + overrides_bkp = deepcopy(overrides) super(PostTrainLinearQuantizer, self).__init__(model, bits_activations=bits_activations, bits_weights=bits_parameters, bits_bias=bits_accum, overrides=overrides, train_with_fp_copy=False) - + if fp16 and str(fpq_module) not in ('16', 'None'): + raise ValueError('Conflict - fp16 set to true and fpq_module set to other than 16.') mode = verify_quant_mode(mode) clip_acts = verify_clip_mode(clip_acts) if clip_acts == ClipMode.N_STD and clip_n_stds is None: @@ -964,39 +1076,71 @@ class PostTrainLinearQuantizer(Quantizer): 'mode': str(mode).split('.')[1], 'clip_acts': _enum_to_str(clip_acts), 'clip_n_stds': clip_n_stds, + 'clip_half_range': clip_half_range, 'per_channel_wts': per_channel_wts, - 'fp16': fp16, - 'scale_approx_mult_bits': scale_approx_mult_bits}} + 'scale_approx_mult_bits': scale_approx_mult_bits, + 'inputs_quant_auto_fallback': inputs_quant_auto_fallback, + 'fpq_module': fpq_module, + 'model_activation_stats': model_activation_stats, + 'overrides': overrides_bkp}} def replace_param_layer(module, name, qbits_map, per_channel_wts=per_channel_wts, mode=mode, fp16=fp16, scale_approx_mult_bits=scale_approx_mult_bits, - clip_acts=clip_acts, clip_half_range=False, clip_n_stds=clip_n_stds): + clip_acts=clip_acts, clip_n_stds=clip_n_stds, clip_half_range=clip_half_range, + input_overrides=None, fpq_module=fpq_module, fake=False): if fp16: - return FP16Wrapper(module) - - # TODO: Try auto-detecting when clip_half_range is needed - # instead of having the user pass it as a parameter (same for replace_non_param_layer) + warnings.warn("Argument 'fp16' is deprecated. 
Please use 'fpq_module'(=16/32/64) argument.", + DeprecationWarning) + fpq_module = fpq_module or 16 norm_name = distiller.utils.normalize_module_name(name) clip_acts = verify_clip_mode(clip_acts) + if fpq_module: + if not fake: + return FPWrapper(module, fpq_module) + else: + return RangeLinearFakeQuantWrapper(module, qbits_map[name].acts, mode=mode, clip_acts=clip_acts, + activation_stats=self.model_activation_stats.get(norm_name, + None), + clip_n_stds=clip_n_stds, clip_half_range=clip_half_range, + scale_approx_mult_bits=scale_approx_mult_bits, + fpq_module=fpq_module) + return RangeLinearQuantParamLayerWrapper(module, qbits_map[name].acts, qbits_map[name].wts, num_bits_accum=self.bits_accum, mode=mode, clip_acts=clip_acts, per_channel_wts=per_channel_wts, activation_stats=self.model_activation_stats.get(norm_name, None), clip_n_stds=clip_n_stds, clip_half_range=clip_half_range, - scale_approx_mult_bits=scale_approx_mult_bits) + scale_approx_mult_bits=scale_approx_mult_bits, + input_overrides=input_overrides, + inputs_quant_auto_fallback=inputs_quant_auto_fallback) def replace_non_param_layer(wrapper_type, module, name, qbits_map, fp16=fp16, scale_approx_mult_bits=scale_approx_mult_bits, - clip_acts=clip_acts, clip_n_stds=clip_n_stds, clip_half_range=False): - if fp16: - return FP16Wrapper(module) + clip_acts=clip_acts, clip_n_stds=clip_n_stds, clip_half_range=clip_half_range, + input_overrides=None, inputs_quant_auto_fallback=inputs_quant_auto_fallback, + fpq_module=fpq_module, fake=False): norm_name = distiller.utils.normalize_module_name(name) clip_acts = verify_clip_mode(clip_acts) + if fp16: + warnings.warn("Argument 'fp16' is deprecated. Please use 'fpq_module'(=16/32/64) argument.", + DeprecationWarning) + fpq_module = fpq_module or 16 + if fpq_module: + if fake: + return RangeLinearFakeQuantWrapper(module, qbits_map[name].acts, mode=mode, clip_acts=clip_acts, + activation_stats=self.model_activation_stats.get(norm_name, None), + clip_n_stds=clip_n_stds, clip_half_range=clip_half_range, + scale_approx_mult_bits=scale_approx_mult_bits, + fpq_module=fpq_module) + else: + return FPWrapper(module, fpq_module) try: return wrapper_type(module, qbits_map[name].acts, mode=mode, clip_acts=clip_acts, activation_stats=self.model_activation_stats.get(norm_name, None), - clip_n_stds=clip_n_stds, clip_half_range=clip_half_range, - scale_approx_mult_bits=scale_approx_mult_bits) + clip_n_stds=clip_n_stds, clip_half_range=clip_half_range, + scale_approx_mult_bits=scale_approx_mult_bits, + input_overrides=input_overrides, + inputs_quant_auto_fallback=inputs_quant_auto_fallback) except NoStatsError: msglogger.warning('WARNING: {0} - quantization of {1} without stats not supported. 
' 'Keeping the original FP32 module'.format(name, module.__class__.__name__)) @@ -1009,6 +1153,31 @@ class PostTrainLinearQuantizer(Quantizer): return RangeLinearEmbeddingWrapper(module, qbits_map[name].wts, mode=mode, stats=self.model_activation_stats.get(norm_name, None)) + def replace_fake_quant(module, name, qbits_map, fp16=fp16, + clip_acts=clip_acts, clip_n_stds=clip_n_stds, clip_half_range=clip_half_range, + scale_approx_mult_bits=scale_approx_mult_bits, fpq_module=fpq_module, fake=True, + make_identity=False): + if isinstance(module, (nn.ReLU, nn.ReLU6)) and make_identity: + named_modules = OrderedDict(self.model.named_modules()) + pred = self.adjacency_map[name].predecessors[0].name + if isinstance(named_modules[pred], RangeLinearQuantWrapper): + return nn.Identity() + norm_name = distiller.utils.normalize_module_name(name) + clip_acts = verify_clip_mode(clip_acts) + if distiller.has_children(module): + return module + if fp16: + warnings.warn("Argument 'fp16' is deprecated. Please use 'fpq_module'(=16/32/64) argument.", + DeprecationWarning) + fpq_module = 16 + if not fake: + return FPWrapper(module, fpq_module) + return RangeLinearFakeQuantWrapper(module, qbits_map[name].acts, mode=mode, clip_acts=clip_acts, + activation_stats=self.model_activation_stats.get(norm_name, None), + clip_n_stds=clip_n_stds, clip_half_range=clip_half_range, + scale_approx_mult_bits=scale_approx_mult_bits, + fpq_module=fpq_module) + self.clip_acts = clip_acts self.clip_n_stds = clip_n_stds self.model_activation_stats = model_activation_stats or {} @@ -1040,6 +1209,8 @@ class PostTrainLinearQuantizer(Quantizer): self.replacement_factory[distiller.modules.BatchMatmul] = factory_matmul self.replacement_factory[nn.Embedding] = replace_embedding + self.default_repalcement_fn = replace_fake_quant + save_dir = msglogger.logdir if hasattr(msglogger, 'logdir') else '.' 
self.save_per_layer_parameters(save_dir) @@ -1067,7 +1238,8 @@ class PostTrainLinearQuantizer(Quantizer): model_activation_stats=args.qe_stats_file, clip_n_stds=args.qe_clip_n_stds, scale_approx_mult_bits=args.qe_scale_approx_bits, - overrides=overrides) + overrides=overrides, + inputs_quant_auto_fallback=True) def save_per_layer_parameters(self, save_dir=''): defaults = OrderedDict(self.model.quantizer_metadata['params']) @@ -1110,6 +1282,7 @@ class PostTrainLinearQuantizer(Quantizer): if not self.has_bidi_distiller_lstm: self._apply_bn_folding(dummy_input) self._apply_activation_stats_fusions() + self._apply_fuse_relu() else: self._apply_bidi_distiller_lstm_stats_fusion() @@ -1117,6 +1290,14 @@ class PostTrainLinearQuantizer(Quantizer): save_path = os.path.join(save_dir, 'quant_stats_after_prepare_model.yaml') distiller.yaml_ordered_save(save_path, self.model_activation_stats) msglogger.info('Updated stats saved to ' + save_path) + # for module_name, override in self.module_overrides_map.items(): + # # Hack to bypass Quantizer pre-override check - + # # Quantizer class checks `qbit.acts` and `qbit.wts` before applying overrides + # # but since fp16 doesn't act as an intN - we need to override these + # # tensors to bypass the check + # if (override.get('fp16', False) or override.get('fpq_module', False)) and \ + # not override.get('fake', False): + # self.module_qbits_map[module_name] = QBits('fp', None, None) def _clip_stats(self, entry, min_val, max_val): if entry['max'] < min_val: @@ -1131,7 +1312,7 @@ class PostTrainLinearQuantizer(Quantizer): def _apply_bn_folding(self, dummy_input): msglogger.info('Applying batch-norm folding ahead of post-training quantization') - mt.fold_batch_norms_inference(self.model, adjacency_map=self.adjacency_map) + mt.fold_batch_norms(self.model, adjacency_map=self.adjacency_map, inference=True) # After BN folding model need to re-generate the adjacency map summary_graph = distiller.SummaryGraph(self.model, dummy_input) @@ -1145,16 +1326,20 @@ class PostTrainLinearQuantizer(Quantizer): named_modules = OrderedDict(self.model.named_modules()) model_stats = self.model_activation_stats for n, m in named_modules.items(): - try: - # Look for the mark left by distiller.model_transforms.fold_batch_norms - folded_bn_module = distiller.normalize_module_name(m.fused_modules[0]) - # Propagate the output stats of the folded BN module to this module - # If stats were collected after folding was applied, then stats for the BN module won't exist, - # in which case we just move on - model_stats[distiller.normalize_module_name(n)]['output'] = model_stats.pop(folded_bn_module)['output'] - msglogger.debug(' {} --> {}'.format(folded_bn_module, n)) - except (AttributeError, KeyError): + # Look for the mark left by distiller.model_transforms.fold_batch_norms + fused_modules = getattr(m, 'fused_modules', None) + if fused_modules is None: continue + folded_bn_module = distiller.normalize_module_name(fused_modules[0]) + + # Propagate the output stats of the folded BN module to this module + # If stats were collected after folding was applied, then stats for the BN module won't exist, + # in which case we just move on + folded_bn_stats = model_stats.pop(folded_bn_module, None) + if folded_bn_stats is None: + continue + model_stats[distiller.normalize_module_name(n)]['output'] = folded_bn_stats['output'] + msglogger.debug(' {} --> {}'.format(folded_bn_module, n)) def _apply_activation_stats_fusions(self): # Now we look for certain "fusions" of layers and activations @@ -1169,7 
+1354,8 @@ class PostTrainLinearQuantizer(Quantizer): named_modules = OrderedDict(self.model.named_modules()) model_stats = self.model_activation_stats for n, m in named_modules.items(): - if distiller.has_children(m) or n not in self.adjacency_map or len(self.adjacency_map[n].successors) != 1: + if (distiller.has_children(m) and not isinstance(m, SimulatedFoldedBatchNorm) )\ + or n not in self.adjacency_map or len(self.adjacency_map[n].successors) != 1: continue successor = self.adjacency_map[n].successors[0] n = distiller.normalize_module_name(n) @@ -1192,23 +1378,53 @@ class PostTrainLinearQuantizer(Quantizer): succ_type = 'Sigmoid' succ_stats = None + # Set the clipping values if succ_type == 'Relu': # ReLU zeros out all negative values, so there's no need to quantize them - msglogger.debug(' Module {} followed by Relu, updating stats'.format(n)) - if succ_stats is not None: - m_stats['output'] = deepcopy(succ_stats['output']) - succ_stats['inputs'][0] = deepcopy(succ_stats['output']) - else: - msglogger.debug(" Relu op not a module or post-split, can't update mean and std".format(n)) - self._clip_stats(m_stats['output'], 0., m_stats['output']['max']) + min_val = 0. + max_val = m_stats['output']['max'] elif succ_type == 'Sigmoid' or succ_type == 'Tanh': # Tanh / Sigmoid saturate at ~4 / ~6 respectively. No need to quantize their inputs outside # of these ranges - msglogger.debug(' Module {} followed by {}, updating stats'.format(n, succ_type)) - sat_val = 4. if succ_type == 'Tanh' else 6. - self._clip_stats(m_stats['output'], -sat_val, sat_val) - if succ_stats is not None: - succ_stats['inputs'][0] = deepcopy(m_stats['output']) + max_val = 4. if succ_type == 'Tanh' else 6. + min_val = -max_val + elif isinstance(named_modules.get(successor.name, None), nn.ReLU6): + succ_type = 'ReLU6' + # ReLU zeros out all negative values, so there's no need to quantize them + min_val = 0. + max_val = min(m_stats['output']['max'], 6) + else: + continue + + # Clip the stats + msglogger.debug(' Module {} followed by {}, updating stats'.format(n, succ_type)) + self._clip_stats(m_stats['output'], min_val, max_val) + if succ_stats is not None: + succ_stats['inputs'][0] = deepcopy(m_stats['output']) + + def _apply_fuse_relu(self): + """Fuses ReLU layers to the linear layers before them.""" + model_overrides = self.module_overrides_map + qbits_map = self.module_qbits_map + named_modules = dict(self.model.named_modules()) + for n, m in named_modules.items(): + # Don't fuse if the module isn't quantized: + qbits = qbits_map.get(n, QBits(None, None, None)) + if qbits.acts is None and qbits.wts is None: + continue + if (distiller.has_children(m) and not isinstance(m, SimulatedFoldedBatchNorm))\ + or n not in self.adjacency_map or len(self.adjacency_map[n].successors) != 1: + continue + successor = self.adjacency_map[n].successors[0] + successor_module = named_modules.get(successor.name, None) + # Add half range clipping to module overrides + m_override = model_overrides.get(n, OrderedDict()) + model_overrides[n] = m_override + if successor.name in named_modules and isinstance(successor_module, (nn.ReLU, nn.ReLU6)): + m_override['clip_half_range'] = True + m_override = model_overrides.get(successor.name, OrderedDict()) + m_override['make_identity'] = True + model_overrides[successor.name] = m_override def _apply_bidi_distiller_lstm_stats_fusion(self): distiller_lstm_cells = [n for n, m in self.model.named_modules() if @@ -1220,6 +1436,18 @@ class PostTrainLinearQuantizer(Quantizer): sat_val = 6. 
self._clip_stats(self.model_activation_stats[name]['output'], -sat_val, sat_val) + def _post_prepare_model(self): + if isinstance(self.model, nn.DataParallel): + # We restore the buffers to the master-GPU of the modules: + device = self.model.src_device_obj + m = self.model.module + for param in m.parameters(): + param.data = param.data.to(device) + for buffer in m.buffers(): + buffer.data = buffer.data.to(device) + + + ############################################################################### # Quantization-aware training diff --git a/distiller/summary_graph.py b/distiller/summary_graph.py index 5df0c6b6634f3ff3994d0be4d4666b8b4bb71f28..5c951f9b79d0763005de80c91406b87de77ff6ab 100755 --- a/distiller/summary_graph.py +++ b/distiller/summary_graph.py @@ -72,6 +72,10 @@ class SummaryGraph(object): def __init__(self, model, dummy_input, apply_scope_name_workarounds=True): self._src_model = model + self._named_modules = OrderedDict(model.named_modules()) + self._adj_map = None + self._layers_topological_order = None + self._top_level_ops = set() model_clone = distiller.make_non_parallel_copy(model) # Switch all instances of torch.nn.ModuleList in the model to our DistillerModuleList @@ -82,6 +86,7 @@ class SummaryGraph(object): device = distiller.model_device(model_clone) dummy_input = distiller.convert_tensors_recursively_to(dummy_input, device=device) + self.dummy_input = dummy_input trace, _ = jit.get_trace_graph(model_clone, dummy_input, _force_outplace=True) # As of PyTorch 1.1.0, ONNX trace optimization has two issues that result in incorrect scope names @@ -185,7 +190,9 @@ class SummaryGraph(object): # so we "unroll" them. same_module_cnt = len(self.module_ops_map[module_name]) if same_module_cnt: - new_op['name'] += "__" + str(same_module_cnt) + # TODO: Was this meant to be applied only to 'top_level_ops'? Also, it's not + # applied to the first module that had the same name + new_op['name'] += "_%s_%d" % (new_op['type'], same_module_cnt) self.module_ops_map[module_name].append(new_op['name']) # Finally we register the new op in the ops collection @@ -506,6 +513,13 @@ class SummaryGraph(object): self._src_model, normalized_layer_name) yield sgraph_layer_name, param_name, param + def _dedicated_module_check(self, n, dedicated_modules_only=False): + if not dedicated_modules_only: + return True + module_name = self.ops[n]['module-name'] + module = self._named_modules[module_name] + return len(self.module_ops_map[module_name]) == 1 and not distiller.has_children(module) + def adjacency_map(self, dedicated_modules_only=False): """Returns a mapping from each op in the graph to its immediate predecessors and successors. @@ -519,35 +533,107 @@ class SummaryGraph(object): associated with a dedicated module within the underlying model. Examples of this will be functional calls, such as "F.relu()", and tensor operations, such as "t3 = t1 + t2". 
""" + if self._adj_map and not dedicated_modules_only: + return self._adj_map adj_map = OrderedDict() - named_modules = OrderedDict(self._src_model.named_modules()) for op_name, op in self.ops.items(): - def dedicated_module_check(n): - if not dedicated_modules_only: - return True - module_name = self.ops[n]['module-name'] - module = named_modules[module_name] - return len(self.module_ops_map[module_name]) == 1 and not distiller.has_children(module) def op_meta(n): return OpSimpleMetadata(distiller.denormalize_module_name(self._src_model, n), self.ops[n]['type']) - if not dedicated_module_check(op_name): + if not self._dedicated_module_check(op_name, dedicated_modules_only): continue entry = AdjacentsEntry(op_meta(op_name)) # Find the immediate preceding and succeeding modules. Depth of 1 gets us the # input and output tensors, depth of 2 gets the actual modules entry.predecessors = [op_meta(n) for n in self.predecessors(op, 2, denorm_names=False) - if dedicated_module_check(n)] + if self._dedicated_module_check(n, dedicated_modules_only)] entry.successors = [op_meta(n) for n in self.successors(op, 2, denorm_names=False) - if dedicated_module_check(n)] + if self._dedicated_module_check(n, dedicated_modules_only)] adj_map[entry.op_meta.name] = entry - + self._adj_map = adj_map return adj_map + def layers_topological_order(self, recurrent=False): + """ + Prepares an ordered list of layers to quantize sequentially. This list has all the layers ordered by their + topological order in the graph. + Args: + recurrent (bool): indication on whether the model might have recurrent connections. + """ + if self._layers_topological_order: + return self._layers_topological_order + adj_map = self.adjacency_map() + ranked_ops = OrderedDict([(k, _OpRank(v, 0)) for k, v in adj_map.items()]) + + def _recurrent_ancestor(ranked_ops_dict, dest_op_name, src_op_name): + def _is_descendant(parent_op_name, dest_op_name): + successors_names = [op.name for op in adj_map[parent_op_name].successors] + if dest_op_name in successors_names: + return True + for succ_name in successors_names: + if _is_descendant(succ_name, dest_op_name): + return True + return False + + return _is_descendant(dest_op_name, src_op_name) and \ + (0 < ranked_ops_dict[dest_op_name].rank < ranked_ops_dict[src_op_name].rank) + + def rank_op(ranked_ops_dict, op_name, rank): + ranked_ops_dict[op_name].rank = rank + for child_op in adj_map[op_name].successors: + # In recurrent models: if a successor is also an ancestor - we don't increment its rank. + if not recurrent or not _recurrent_ancestor(ranked_ops_dict, child_op.name, op_name): + rank_op(ranked_ops_dict, child_op.name, ranked_ops_dict[op_name].rank + 1) + + roots = [k for k, v in adj_map.items() if len(v.predecessors) == 0] + for root_op_name in roots: + rank_op(ranked_ops, root_op_name, 0) + + # Take only the modules from the original model + module_dict = dict(self._src_model.named_modules()) + ret = sorted([k for k in ranked_ops.keys() if k in module_dict], + key=lambda k: ranked_ops[k].rank) + # Check that only the actual roots have a rank of 0 + assert {k for k in ret if ranked_ops[k].rank == 0} <= set(roots) + self._layers_topological_order = ret + return ret + + def top_level_ops(self): + if self._top_level_ops: + return self._top_level_ops + for op_name in self.ops: + if not self.predecessors(op_name, 1): + self._top_level_ops.add(op_name) + return self._top_level_ops + + def missing_modules(self): + """ + Returns a list of ops that aren't registered as modules. 
+ """ + return [op_name for op_name in self.adjacency_map() + if not self._dedicated_module_check(op_name, True)] + + +class _OpRank: + def __init__(self, adj_entry, rank=None): + self.adj_entry = adj_entry + self._rank = rank or 0 + + @property + def rank(self): + return self._rank + + @rank.setter + def rank(self, val): + self._rank = max(val, self._rank) + + def __repr__(self): + return '_OpRank(\'%s\' | %d)' % (self.adj_entry.op_meta.name, self.rank) + class OpSimpleMetadata(object): def __init__(self, name, type): diff --git a/distiller/utils.py b/distiller/utils.py index d1c1feba9b825eab9ba81ef2b229b0da925d8d65..3b342b57a5d3181436986984e2e401e49fba14fc 100755 --- a/distiller/utils.py +++ b/distiller/utils.py @@ -739,3 +739,35 @@ def convert_tensors_recursively_to(val, *args, **kwargs): return val + +# TODO: Is this needed? +def model_setattr(model, attr_name, val, register=False): + """ + Sets attribute of a model, through the entire hierarchy. + Args: + model (nn.Module): the model. + attr_name (str): the attribute name as shown by model.named_<parameters/modules/buffers>() + val: the value of the attribute + register (bool): if True - register_buffer(val) if val is a torch.Tensor and + register_parameter(val) if it's an nn.Parameter. + """ + def split_name(name): + if '.' in name: + return name.rsplit('.', 1) + else: + return '', name + modules_dict = OrderedDict(model.named_modules()) + lowest_depth_container_name, lowest_depth_attr_name = split_name(attr_name) + while lowest_depth_container_name and lowest_depth_container_name not in modules_dict: + container_name, attr = split_name(lowest_depth_container_name) + lowest_depth_container_name = container_name + lowest_depth_attr_name = '%s%s' % (attr, lowest_depth_attr_name) + lowest_depth_container = modules_dict[lowest_depth_container_name] # type: nn.Module + + if register and torch.is_tensor(val): + if isinstance(val, nn.Parameter): + lowest_depth_container.register_parameter(lowest_depth_attr_name, val) + else: + lowest_depth_container.register_buffer(lowest_depth_attr_name, val) + else: + setattr(lowest_depth_container, lowest_depth_attr_name, val) diff --git a/examples/quantization/post_train_quant/resnet18_imagenet_post_train_input_overrides.yaml b/examples/quantization/post_train_quant/resnet18_imagenet_post_train_input_overrides.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7427dcdd625c8efc42a9ddae5b005096dd8b31f1 --- /dev/null +++ b/examples/quantization/post_train_quant/resnet18_imagenet_post_train_input_overrides.yaml @@ -0,0 +1,75 @@ +# Specifying quantization settings for layer INPUTS +# ------------------------------------------------- +# +# This configuration file demonstrates how to override the default behavior of PostTrainLinearQuantizer +# with regards to quantizing inputs of layers. +# +# THIS APPLIES ONLY TO PostTraingLinearQuantizer +# +# 1. By default, settings for quantizing activations are applied to OUTPUTs of layers +# +# 2. When the output of a layer is quantized, the quantization settings (aka metadata) are attached to +# the tensor. +# +# 3. Certain layers expect quantized inputs. That is - they expect input tensors that have quantization +# metadata attached. +# +# 4. But, there are cases where this metadata will not be available, such as: +# 4.1 "Root" inputs of the model (which are not the output of any layer) +# 4.2 When two quantized layers are separated by one or more non-quantized operations. 
+# This can happen if the user explicitly requested that these operations not be quantized, or if +# they are operations that can't be detected by the quantizer. +# +# 5. In these cases, the default behavior of the quantizer is to quantize the layer inputs according +# to the settings that were provided for the layer's outputs (see item 1). +# This behavior is controlled by the 'inputs_quant_auto_fallback' parameter. By default it is set +# to True. +# +# 6. One can override this default behaviour and explicitly set settings for inputs. Below we show a +# couple of examples for this. +# Note that when the input tensor does have quantization metadata attached, specifying explicit +# settings isn't allowed, and an error will be raised. This is done in order to maintain maximum +# consistency with quantization settings flowing through the model. + +quantizers: + post_train_quantizer: + class: PostTrainLinearQuantizer + bits_activations: 8 + bits_parameters: 8 + bits_accum: 32 + mode: ASYMMETRIC_UNSIGNED + # Path to stats file assuming this is being invoked from the 'classifier_compression' example directory + model_activation_stats: ../quantization/post_train_quant/stats/resnet18_quant_stats.yaml + per_channel_wts: True + clip_acts: AVG + + # If this is set, when a quantized input is required but the tensor doesn't contain quantization metadata, + # the input will be quantized according to the module's output settings + inputs_quant_auto_fallback: False + + overrides: + conv1: + # The input to the first layer in the model will never have quantization metadata (item 4.1 above) + # Since layers could have multiple inputs, we specify the index of the required input as a key + input_overrides: + 0: + # Shorthand to take the quantization settings of the output (ignores any other settings) + from_outputs: True + + fc: + # Here we show an example of mixing output overrides and input overrides + + # We don't clip the output of the last layer + clip_acts: None + + # In ResNet, the FC layer has a view op before, which kills the quantization metadata (item 4.2 above). + # So we override. + # But here we don't want to take the settings from the output, where we just set clip_acts to None. + # So we specify clip_acts explicitly. + input_overrides: + 0: + # Example of setting the actual value. Applicable only if 'from_outputs' isn't set. + # The following keys are supported: 'bits_activations', 'mode', 'clip_acts', 'clip_n_stds' + # Any key not explicitly set will default to the output setting + clip_acts: AVG + diff --git a/tests/full_flow_tests.py b/tests/full_flow_tests.py index ef466198d1162a63fe8051224fbd92c9808be05e..d8617b7f78b263690b058714cd27c251a61d47e1 100755 --- a/tests/full_flow_tests.py +++ b/tests/full_flow_tests.py @@ -123,7 +123,7 @@ test_configs = [ TestConfig('--arch simplenet_cifar --epochs 2', DS_CIFAR, accuracy_checker, [44.460, 91.230]), TestConfig('-a resnet20_cifar --resume {0} --quantize-eval --evaluate --qe-clip-acts avg --qe-no-clip-layers {1}'. format(os.path.join(examples_root, 'ssl', 'checkpoints', 'checkpoint_trained_dense.pth.tar'), 'fc'), - DS_CIFAR, accuracy_checker, [91.64, 99.63]), + DS_CIFAR, accuracy_checker, [91.57, 99.62]), TestConfig('-a preact_resnet20_cifar --epochs 2 --compress {0}'. 
format(os.path.join('full_flow_tests', 'preact_resnet20_cifar_pact_test.yaml')), DS_CIFAR, accuracy_checker, [44.370, 89.640]), diff --git a/tests/test_model_transforms.py b/tests/test_model_transforms.py index 69a4fb0979392f3dbda64d8eb02e2ff38532ace2..9ac33534ef379b67ad9dffd77e1a7d9ebfeb0398 100644 --- a/tests/test_model_transforms.py +++ b/tests/test_model_transforms.py @@ -225,7 +225,7 @@ def test_fuse_modules_with_pre_exist_adj_map(): ) def test_fold_batch_norms_inference_no_fold(model, input_shape): orig_model = deepcopy(model) - folded_model = mt.fold_batch_norms_inference(model, dummy_input=torch.randn(input_shape)) + folded_model = mt.fold_batch_norms(model, dummy_input=torch.randn(input_shape), inference=True) for (n_orig, m_orig), (n_folded, m_folded) in zip(orig_model.named_modules(), folded_model.named_modules()): assert n_folded == n_orig assert type(m_folded) == type(m_orig) @@ -255,7 +255,7 @@ def test_fold_batch_norms_inference(model, input_shape): model.eval() orig_model = deepcopy(model) dummy_input = torch.randn(input_shape) - folded_model = mt.fold_batch_norms_inference(model, dummy_input=dummy_input) + folded_model = mt.fold_batch_norms(model, dummy_input=dummy_input, inference=True) assert type(folded_model.seq[0]) == type(orig_model.seq[0]) assert type(folded_model.seq[1]) == nn.Identity diff --git a/tests/test_post_train_quant.py b/tests/test_post_train_quant.py index 8bdc3fb69034ee53b0c531138561762db9912bb5..685c956cf660ad71898d076e222ad5df878d5602 100644 --- a/tests/test_post_train_quant.py +++ b/tests/test_post_train_quant.py @@ -25,11 +25,28 @@ from copy import deepcopy from distiller.quantization import RangeLinearQuantParamLayerWrapper, LinearQuantMode, ClipMode, \ RangeLinearQuantConcatWrapper, RangeLinearQuantEltwiseMultWrapper, RangeLinearQuantEltwiseAddWrapper, \ PostTrainLinearQuantizer +from distiller.quantization.range_linear import _get_quant_params_from_tensor, _get_quant_params_from_stats_dict,\ + TensorQuantMetadata +from distiller.quantization import q_utils from distiller.data_loggers import QuantCalibrationStatsCollector, collector_context import distiller.modules from common import WrappedSequential +def attach_quant_metadata(t, num_bits, quant_mode, stats=None, clip_mode=ClipMode.NONE, per_channel=False, + num_stds=None, scale_approx_mult_bits=None): + if stats is None: + scale, zp = _get_quant_params_from_tensor(t, num_bits, quant_mode, clip_mode, per_channel, num_stds, + scale_approx_mult_bits) + else: + scale, zp = _get_quant_params_from_stats_dict(stats, num_bits, quant_mode, clip_mode, num_stds, + scale_approx_mult_bits) + signed = quant_mode != LinearQuantMode.ASYMMETRIC_UNSIGNED + min_q_val, max_q_val = q_utils.get_quantized_range(num_bits, signed) + t.quant_metadata = TensorQuantMetadata(scale, zp, min_q_val, max_q_val) + return t + + ############################################################################### # Test Convolution ############################################################################### @@ -121,6 +138,10 @@ def test_conv_layer_wrapper(conv_input, conv_weights, mode, clip_acts, per_chann model = RangeLinearQuantParamLayerWrapper(layer, 8, 8, mode=mode, clip_acts=clip_acts, per_channel_wts=per_channel_wts, activation_stats=conv_stats) + input_stats = None if conv_stats is None else conv_stats['inputs'][0] + conv_input = attach_quant_metadata(conv_input, 8, mode, stats=input_stats, clip_mode=clip_acts, + per_channel=False, num_stds=None, scale_approx_mult_bits=None) + with pytest.raises(RuntimeError): 
model(conv_input) @@ -174,6 +195,9 @@ def test_linear_layer_wrapper(linear_input, linear_weights, linear_bias, model = RangeLinearQuantParamLayerWrapper(layer, 8, 8, mode=mode, clip_acts=clip_acts, per_channel_wts=per_channel_wts) + linear_input = attach_quant_metadata(linear_input, 8, mode, stats=None, clip_mode=clip_acts, + per_channel=False, num_stds=None, scale_approx_mult_bits=None) + with pytest.raises(RuntimeError): model(linear_input) @@ -196,7 +220,7 @@ def inputs(): in_1_b_0 = torch.tensor([[[[-3, 6], [0, 8]], [[4, 10], [-7, 1]]]], dtype=torch.float32) in_1_b_1 = torch.tensor([[[[-100, 50], [6, 12]], [[80, -30], [-16, 3]]]], dtype=torch.float32) in_1 = torch.cat((in_1_b_0, in_1_b_1), 0) - return in_0, in_1 + return [in_0, in_1] input_stats = OrderedDict() @@ -264,6 +288,11 @@ def test_concat_layer_wrapper(inputs, concat_stats, mode, clip_acts, expected_ou # Check exception on no stats RangeLinearQuantConcatWrapper(layer, 8, mode, clip_acts, activation_stats=None) + for idx in range(len(inputs)): + inputs[idx] = attach_quant_metadata(inputs[idx], 8, mode, stats=concat_stats['inputs'][idx], + clip_mode=clip_acts, per_channel=False, num_stds=None, + scale_approx_mult_bits=None) + model = RangeLinearQuantConcatWrapper(layer, 8, mode, clip_acts, concat_stats) model.eval() output = model(*inputs) @@ -319,6 +348,11 @@ def test_eltwise_mult_layer_wrapper(inputs, eltwise_mult_stats, mode, clip_acts, # Check exception on no stats RangeLinearQuantEltwiseMultWrapper(layer, 8, mode, clip_acts, activation_stats=None) + for idx in range(len(inputs)): + inputs[idx] = attach_quant_metadata(inputs[idx], 8, mode, stats=eltwise_mult_stats['inputs'][idx], + clip_mode=clip_acts, per_channel=False, num_stds=None, + scale_approx_mult_bits=None) + model = RangeLinearQuantEltwiseMultWrapper(layer, 8, mode, clip_acts, eltwise_mult_stats) model.eval() output = model(*inputs) @@ -374,6 +408,11 @@ def test_eltwise_add_layer_wrapper(inputs, eltwise_add_stats, mode, clip_acts, e # Check exception on no stats RangeLinearQuantEltwiseAddWrapper(layer, 8, mode, clip_acts, activation_stats=None) + for idx in range(len(inputs)): + inputs[idx] = attach_quant_metadata(inputs[idx], 8, mode, stats=eltwise_add_stats['inputs'][idx], + clip_mode=clip_acts, per_channel=False, num_stds=None, + scale_approx_mult_bits=None) + model = RangeLinearQuantEltwiseAddWrapper(layer, 8, mode, clip_acts, eltwise_add_stats) model.eval() output = model(*inputs) @@ -439,8 +478,8 @@ def test_override_no_clip(overrides, e_clip_acts, e_n_stds, rnn_model, rnn_model model_activation_stats=rnn_model_stats) quantizer.prepare_model(torch.randn(1, 1, 20)) assert isinstance(quantizer.model.rnn.cells[0].eltwisemult_hidden, RangeLinearQuantEltwiseMultWrapper) - assert quantizer.model.rnn.cells[0].eltwisemult_hidden.clip_acts == e_clip_acts - assert quantizer.model.rnn.cells[0].eltwisemult_hidden.clip_n_stds == e_n_stds + assert quantizer.model.rnn.cells[0].eltwisemult_hidden.output_quant_settings.clip_mode == e_clip_acts + assert quantizer.model.rnn.cells[0].eltwisemult_hidden.output_quant_settings.clip_n_stds == e_n_stds ############################################################################### @@ -551,7 +590,7 @@ def test_stats_fusion_just_bn(): @pytest.mark.parametrize( 'act_type, act_as_module, bn_out_stats, conv_out_expected_stats', [ - ('relu', True, stats_entry(-5., 5., -3., 3., 0., 0.5), None), + ('relu', True, stats_entry(-5., 5., -3., 3., 0., 0.5), stats_entry(0., 5., 0, 3., 0., 0.5)), ('relu', False, stats_entry(-5., 5., -3., 3., 0., 
0.5), stats_entry(0., 5., 0, 3., 0., 0.5)), ('relu', False, stats_entry(1., 5., 2., 3., 2.5, 0.5), stats_entry(1., 5., 2., 3., 2.5, 0.5)), ('relu', False, stats_entry(-5., -1., -4., -2., -2.5, 0.5), stats_entry(0., 0, 0, 0., -2.5, 0.5)), @@ -577,16 +616,9 @@ def test_stats_fusion_sequential(act_type, act_as_module, bn_out_stats, conv_out expected = deepcopy(stats) expected.pop('bn') # After BN folding BN stats are removed - if act_type == 'relu': - if act_as_module: - expected['conv']['output'] = deepcopy(stats['act']['output']) - expected['act']['inputs'][0] = deepcopy(stats['act']['output']) - else: - expected['conv']['output'] = conv_out_expected_stats - else: - expected['conv']['output'] = conv_out_expected_stats - if act_as_module: - expected['act']['inputs'][0] = conv_out_expected_stats + expected['conv']['output'] = conv_out_expected_stats + if act_as_module: + expected['act']['inputs'][0] = conv_out_expected_stats assert quantizer.model_activation_stats == expected diff --git a/tests/test_quantizer.py b/tests/test_quantizer.py index 049e563eda6d0c8034a9414af0bef43b58181716..467aef8e3ddf64119dff7d5dd98c664a7845de68 100644 --- a/tests/test_quantizer.py +++ b/tests/test_quantizer.py @@ -396,13 +396,6 @@ def test_param_quantization(model, optimizer, qbits, overrides, explicit_expecte def test_overridable_args(model, optimizer, train_with_fp_copy): - model_copy = deepcopy(model) - conv_override = OrderedDict([(acts_key, None), (wts_key, None), (bias_key, None), ('prop', 123)]) - overrides = OrderedDict([('conv1', conv_override)]) - q = DummyQuantizer(model_copy, optimizer=optimizer, overrides=overrides, train_with_fp_copy=train_with_fp_copy) - pytest_raises_wrapper(ValueError, 'Expecting ValueError when overriding args without overriding bits', - q.prepare_model) - model_copy = deepcopy(model) conv_override = OrderedDict([(acts_key, 8), (wts_key, 8), (bias_key, 32), ('prop', 123), ('unexpetcted_prop', 456)]) overrides = OrderedDict([('conv1', conv_override)]) diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py index 293f69bd4049ad41baf7d88a827c977689caed51..98feba80f7aefed82e3bd3e101e5266a44455d91 100755 --- a/tests/test_summarygraph.py +++ b/tests/test_summarygraph.py @@ -125,7 +125,7 @@ def test_layer_search(parallel, denorm_names): assert preds == prefix_strs(['conv1'], prefix) preds = g.predecessors_f('layer1.1.conv1', 'Conv', [], logging, denorm_names=denorm_names) - assert preds == prefix_strs(['layer1.0.conv2', 'conv1'], prefix) + assert preds == prefix_strs(['conv1', 'layer1.0.conv2'], prefix) def test_vgg(): @@ -306,7 +306,7 @@ def test_scope_name_workarounds(): # may fix them, in which case we can remove the workarounds sg = SummaryGraph(m, dummy_input, apply_scope_name_workarounds=False) names, types = zip(*[(op_name, op['type']) for op_name, op in sg.ops.items()]) - assert names == ('drop1', 'drop2', 'drop2__1', 'relu2', 'drop3') + assert names == ('drop1', 'drop2', 'drop2_Gemm_1', 'relu2', 'drop3') assert types == expected_types # Now test with the workarounds