From 3cde6c5eb68ed8e6e279b24f4d0283d7caed99ce Mon Sep 17 00:00:00 2001
From: Lev Zlotnik <46742999+levzlotnik@users.noreply.github.com>
Date: Mon, 3 Jun 2019 15:14:10 +0300
Subject: [PATCH] [Breaking] PTQ: Removed special handling of clipping overrides

* In PostTrainLinearQuantizer - moved 'clip_acts' and 'clip_n_stds' to overrides,
  removed 'no_clip_layers' parameter from __init__
* The 'no_clip_layers' command line argument REMAINS, handled in
  PostTrainLinearQuantizer.from_args()
* Removed old code from comments, fixed warnings in test_post_train_quant.py
* Updated tests
* Update post-train quant sample YAML
---
 distiller/quantization/range_linear.py         | 29 +++++-----
 .../resnet18_imagenet_post_train.yaml          |  2 +-
 tests/full_flow_tests.py                       |  6 +-
 tests/test_post_train_quant.py                 | 57 ++++++++++++++++++-
 4 files changed, 75 insertions(+), 19 deletions(-)

diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index fb460c0..9545a05 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -768,9 +768,6 @@ class PostTrainLinearQuantizer(Quantizer):
         overrides (:obj:`OrderedDict`, optional): Overrides the layers quantization settings.
         mode (LinearQuantMode): Quantization mode to use (symmetric / asymmetric-signed / unsigned)
         clip_acts (ClipMode): Activations clipping mode to use
-        no_clip_layers (list): List of fully-qualified layer names for which activations clipping should not be done.
-            A common practice is to not clip the activations of the last layer before softmax.
-            Applicable only if clip_acts is True.
         per_channel_wts (bool): Enable quantization of weights using separate quantization parameters per
             output channel
         model_activation_stats (str / dict / OrderedDict): Either a path to activation stats YAML file, or a dictionary
@@ -788,7 +785,7 @@ class PostTrainLinearQuantizer(Quantizer):
            to half precision, regardless of bits_activations/parameters/accum.
""" def __init__(self, model, bits_activations=8, bits_parameters=8, bits_accum=32, - overrides=None, mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE, no_clip_layers=None, + overrides=None, mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE, per_channel_wts=False, model_activation_stats=None, fp16=False, clip_n_stds=None, scale_approx_mult_bits=None): super(PostTrainLinearQuantizer, self).__init__(model, bits_activations=bits_activations, @@ -817,32 +814,33 @@ class PostTrainLinearQuantizer(Quantizer): 'mode': str(mode).split('.')[1], 'clip_acts': _enum_to_str(clip_acts), 'clip_n_stds': clip_n_stds, - 'no_clip_layers': no_clip_layers, 'per_channel_wts': per_channel_wts, 'fp16': fp16, 'scale_approx_mult_bits': scale_approx_mult_bits}} def replace_param_layer(module, name, qbits_map, per_channel_wts=per_channel_wts, - mode=mode, fp16=fp16, scale_approx_mult_bits=scale_approx_mult_bits): + mode=mode, fp16=fp16, scale_approx_mult_bits=scale_approx_mult_bits, + clip_acts=clip_acts, clip_n_stds=clip_n_stds): if fp16: return FP16Wrapper(module) norm_name = distiller.utils.normalize_module_name(name) - clip = self.clip_acts if norm_name not in self.no_clip_layers else ClipMode.NONE + clip_acts = verify_clip_mode(clip_acts) return RangeLinearQuantParamLayerWrapper(module, qbits_map[name].acts, qbits_map[name].wts, - num_bits_accum=self.bits_accum, mode=mode, clip_acts=clip, + num_bits_accum=self.bits_accum, mode=mode, clip_acts=clip_acts, per_channel_wts=per_channel_wts, activation_stats=self.model_activation_stats.get(norm_name, None), clip_n_stds=clip_n_stds, scale_approx_mult_bits=scale_approx_mult_bits) def replace_non_param_layer(wrapper_type, module, name, qbits_map, fp16=fp16, - scale_approx_mult_bits=scale_approx_mult_bits): + scale_approx_mult_bits=scale_approx_mult_bits, + clip_acts=clip_acts, clip_n_stds=clip_n_stds): if fp16: return FP16Wrapper(module) norm_name = distiller.utils.normalize_module_name(name) - clip = self.clip_acts if norm_name not in self.no_clip_layers else ClipMode.NONE + clip_acts = verify_clip_mode(clip_acts) try: - return wrapper_type(module, qbits_map[name].acts, mode=mode, clip_acts=clip, + return wrapper_type(module, qbits_map[name].acts, mode=mode, clip_acts=clip_acts, activation_stats=self.model_activation_stats.get(norm_name, None), clip_n_stds=clip_n_stds, scale_approx_mult_bits=scale_approx_mult_bits) except NoStatsError: @@ -858,7 +856,6 @@ class PostTrainLinearQuantizer(Quantizer): stats=self.model_activation_stats.get(norm_name, None)) self.clip_acts = clip_acts - self.no_clip_layers = no_clip_layers or [] self.clip_n_stds = clip_n_stds self.model_activation_stats = model_activation_stats or {} self.bits_accum = bits_accum @@ -885,17 +882,21 @@ class PostTrainLinearQuantizer(Quantizer): return distiller.config_component_from_file_by_class(model, args.qe_config_file, 'PostTrainLinearQuantizer') else: + overrides = OrderedDict( + [(layer, OrderedDict([('clip_acts', 'NONE')])) + for layer in args.qe_no_clip_layers] + ) return cls(model, bits_activations=args.qe_bits_acts, bits_parameters=args.qe_bits_wts, bits_accum=args.qe_bits_accum, mode=args.qe_mode, clip_acts=args.qe_clip_acts, - no_clip_layers=args.qe_no_clip_layers, per_channel_wts=args.qe_per_channel, model_activation_stats=args.qe_stats_file, clip_n_stds=args.qe_clip_n_stds, - scale_approx_mult_bits=args.qe_scale_approx_bits) + scale_approx_mult_bits=args.qe_scale_approx_bits, + overrides=overrides) ############################################################################### diff 
diff --git a/examples/quantization/post_train_quant/resnet18_imagenet_post_train.yaml b/examples/quantization/post_train_quant/resnet18_imagenet_post_train.yaml
index 0b6252b..0e07184 100644
--- a/examples/quantization/post_train_quant/resnet18_imagenet_post_train.yaml
+++ b/examples/quantization/post_train_quant/resnet18_imagenet_post_train.yaml
@@ -45,7 +45,6 @@ quantizers:
     model_activation_stats: ../quantization/post_train_quant/stats/resnet18_quant_stats.yaml
     per_channel_wts: True
     clip_acts: AVG
-    no_clip_layers: fc
     overrides:
     # First and last layers + element-wise add layers in 8-bits
       conv1:
@@ -57,3 +56,4 @@ quantizers:
       fc:
         bits_weights: 8
         bits_activations: 8
+        clip_acts: NONE
diff --git a/tests/full_flow_tests.py b/tests/full_flow_tests.py
index 4a27458..01f15c6 100755
--- a/tests/full_flow_tests.py
+++ b/tests/full_flow_tests.py
@@ -116,9 +116,9 @@ TestConfig = namedtuple('TestConfig', ['args', 'dataset', 'checker_fn', 'checker

 test_configs = [
     TestConfig('--arch simplenet_cifar --epochs 2', DS_CIFAR, accuracy_checker, [44.610, 92.080]),
-    TestConfig('-a resnet20_cifar --resume {0} --quantize-eval --evaluate'.
-               format(os.path.join(examples_root, 'ssl', 'checkpoints', 'checkpoint_trained_dense.pth.tar')),
-               DS_CIFAR, accuracy_checker, [91.710, 99.610]),
+    TestConfig('-a resnet20_cifar --resume {0} --quantize-eval --evaluate --qe-clip-acts avg --qe-no-clip-layers {1}'.
+               format(os.path.join(examples_root, 'ssl', 'checkpoints', 'checkpoint_trained_dense.pth.tar'), 'fc'),
+               DS_CIFAR, accuracy_checker, [91.55, 99.63]),
     TestConfig('-a preact_resnet20_cifar --epochs 2 --compress {0}'.
                format(os.path.join('full_flow_tests', 'preact_resnet20_cifar_pact_test.yaml')),
                DS_CIFAR, accuracy_checker, [54.590, 94.810]),
diff --git a/tests/test_post_train_quant.py b/tests/test_post_train_quant.py
index 4522e01..2514c21 100644
--- a/tests/test_post_train_quant.py
+++ b/tests/test_post_train_quant.py
@@ -16,10 +16,13 @@
 import pytest
 import torch
 import torch.testing
+import torch.nn as nn
 from collections import OrderedDict

 from distiller.quantization import RangeLinearQuantParamLayerWrapper, LinearQuantMode, ClipMode, \
-    RangeLinearQuantConcatWrapper, RangeLinearQuantEltwiseMultWrapper, RangeLinearQuantEltwiseAddWrapper
+    RangeLinearQuantConcatWrapper, RangeLinearQuantEltwiseMultWrapper, RangeLinearQuantEltwiseAddWrapper, \
+    PostTrainLinearQuantizer
+from distiller.data_loggers import QuantCalibrationStatsCollector, collector_context
 import distiller.modules


@@ -352,3 +355,55 @@ def test_eltwise_add_layer_wrapper(inputs, eltwise_add_stats, mode, clip_acts, e
     output = model(*inputs)

     torch.testing.assert_allclose(output, expected_output)
+
+
+class DummyRNN(nn.Module):
+    def __init__(self, input_size, hidden_size, num_layers):
+        super(DummyRNN, self).__init__()
+        self.rnn = distiller.modules.DistillerLSTM(input_size, hidden_size, num_layers)
+        self.softmax = nn.Softmax(dim=-1)
+
+    def forward(self, x, h=None):
+        y, h = self.rnn(x, h)
+        y = self.softmax(y)
+        return y, h
+
+
+@pytest.fixture()
+def rnn_model():
+    return DummyRNN(20, 128, 1)
+
+
+@pytest.fixture()
+def rnn_model_stats(rnn_model):
+    collector = QuantCalibrationStatsCollector(rnn_model)
+    dummy_input = torch.randn(35, 50, 20)
+    with collector_context(collector):
+        y,h = rnn_model(dummy_input)
+    return collector.value()
+
+
+@pytest.mark.parametrize(
+    "overrides, e_clip_acts, e_n_stds",
+    [
+        (None, ClipMode.AVG, 0),
+
+        (distiller.utils.yaml_ordered_load("""
+        rnn.cells.0.eltwisemult_hidden:
+            clip_acts: NONE
+        """), ClipMode.NONE, 0),
+
+        (distiller.utils.yaml_ordered_load("""
+        rnn.cells.0.eltwisemult_hidden:
+            clip_acts: N_STD
+            clip_n_stds: 2
+        """), ClipMode.N_STD, 2)
+    ]
+)
+def test_override_no_clip(overrides, e_clip_acts, e_n_stds, rnn_model, rnn_model_stats):
+    quantizer = PostTrainLinearQuantizer(rnn_model, clip_acts="AVG", clip_n_stds=0, overrides=overrides,
+                                         model_activation_stats=rnn_model_stats)
+    quantizer.prepare_model()
+    assert isinstance(quantizer.model.rnn.cells[0].eltwisemult_hidden, RangeLinearQuantEltwiseMultWrapper)
+    assert quantizer.model.rnn.cells[0].eltwisemult_hidden.clip_acts == e_clip_acts
+    assert quantizer.model.rnn.cells[0].eltwisemult_hidden.clip_n_stds == e_n_stds
--
GitLab
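Beyond reproducing the old no-clip behaviour, moving clipping into overrides also lets clip_n_stds be set per layer, which is what the new test above checks. A minimal end-to-end sketch of that flow outside pytest, on a toy feed-forward model (the model and layer names are illustrative assumptions, not part of the patch):

    from collections import OrderedDict

    import torch
    import torch.nn as nn
    import distiller
    from distiller.data_loggers import QuantCalibrationStatsCollector, collector_context
    from distiller.quantization import PostTrainLinearQuantizer

    # Toy model; layer names ('embed', 'classifier') are illustrative.
    model = nn.Sequential(OrderedDict([
        ('embed', nn.Linear(10, 32)),
        ('relu', nn.ReLU()),
        ('classifier', nn.Linear(32, 5)),
    ]))

    # 1. Collect activation stats on calibration data (random data here).
    collector = QuantCalibrationStatsCollector(model)
    with collector_context(collector):
        for _ in range(4):
            model(torch.randn(16, 10))
    stats = collector.value()

    # 2. Quantize with AVG clipping everywhere, except 'classifier', which
    #    gets N_STD clipping with 2 standard deviations via an override.
    overrides = distiller.utils.yaml_ordered_load("""
    classifier:
        clip_acts: N_STD
        clip_n_stds: 2
    """)
    quantizer = PostTrainLinearQuantizer(model, clip_acts='AVG',
                                         overrides=overrides,
                                         model_activation_stats=stats)
    quantizer.prepare_model()
    # 'embed' is now wrapped with AVG clipping; 'classifier' uses N_STD with 2 std devs.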