From 3cde6c5eb68ed8e6e279b24f4d0283d7caed99ce Mon Sep 17 00:00:00 2001
From: Lev Zlotnik <46742999+levzlotnik@users.noreply.github.com>
Date: Mon, 3 Jun 2019 15:14:10 +0300
Subject: [PATCH] [Breaking] PTQ: Removed special handling of clipping overrides

* In PostTrainLinearQuantizer - moved 'clip_acts' and 'clip_n_stds' to overrides,
  removed 'no_clip_layers' parameter from __init__
* The 'no_clip_layers' command line argument REMAINS, handled in
  PostTrainLinearQuantizer.from_args()
* Removed old code from comments, fixed warnings in test_post_train_quant.py
* Updated tests
* Update post-train quant sample YAML
---
 distiller/quantization/range_linear.py         | 29 +++++-----
 .../resnet18_imagenet_post_train.yaml          |  2 +-
 tests/full_flow_tests.py                       |  6 +-
 tests/test_post_train_quant.py                 | 57 ++++++++++++++++++-
 4 files changed, 75 insertions(+), 19 deletions(-)

diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index fb460c0..9545a05 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -768,9 +768,6 @@ class PostTrainLinearQuantizer(Quantizer):
         overrides (:obj:`OrderedDict`, optional): Overrides the layers quantization settings.
         mode (LinearQuantMode): Quantization mode to use (symmetric / asymmetric-signed / unsigned)
         clip_acts (ClipMode): Activations clipping mode to use
-        no_clip_layers (list): List of fully-qualified layer names for which activations clipping should not be done.
-            A common practice is to not clip the activations of the last layer before softmax.
-            Applicable only if clip_acts is True.
         per_channel_wts (bool): Enable quantization of weights using separate quantization parameters per
             output channel
         model_activation_stats (str / dict / OrderedDict): Either a path to activation stats YAML file, or a dictionary
@@ -788,7 +785,7 @@ class PostTrainLinearQuantizer(Quantizer):
            to half precision, regardless of bits_activations/parameters/accum.
""" def __init__(self, model, bits_activations=8, bits_parameters=8, bits_accum=32, - overrides=None, mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE, no_clip_layers=None, + overrides=None, mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE, per_channel_wts=False, model_activation_stats=None, fp16=False, clip_n_stds=None, scale_approx_mult_bits=None): super(PostTrainLinearQuantizer, self).__init__(model, bits_activations=bits_activations, @@ -817,32 +814,33 @@ class PostTrainLinearQuantizer(Quantizer): 'mode': str(mode).split('.')[1], 'clip_acts': _enum_to_str(clip_acts), 'clip_n_stds': clip_n_stds, - 'no_clip_layers': no_clip_layers, 'per_channel_wts': per_channel_wts, 'fp16': fp16, 'scale_approx_mult_bits': scale_approx_mult_bits}} def replace_param_layer(module, name, qbits_map, per_channel_wts=per_channel_wts, - mode=mode, fp16=fp16, scale_approx_mult_bits=scale_approx_mult_bits): + mode=mode, fp16=fp16, scale_approx_mult_bits=scale_approx_mult_bits, + clip_acts=clip_acts, clip_n_stds=clip_n_stds): if fp16: return FP16Wrapper(module) norm_name = distiller.utils.normalize_module_name(name) - clip = self.clip_acts if norm_name not in self.no_clip_layers else ClipMode.NONE + clip_acts = verify_clip_mode(clip_acts) return RangeLinearQuantParamLayerWrapper(module, qbits_map[name].acts, qbits_map[name].wts, - num_bits_accum=self.bits_accum, mode=mode, clip_acts=clip, + num_bits_accum=self.bits_accum, mode=mode, clip_acts=clip_acts, per_channel_wts=per_channel_wts, activation_stats=self.model_activation_stats.get(norm_name, None), clip_n_stds=clip_n_stds, scale_approx_mult_bits=scale_approx_mult_bits) def replace_non_param_layer(wrapper_type, module, name, qbits_map, fp16=fp16, - scale_approx_mult_bits=scale_approx_mult_bits): + scale_approx_mult_bits=scale_approx_mult_bits, + clip_acts=clip_acts, clip_n_stds=clip_n_stds): if fp16: return FP16Wrapper(module) norm_name = distiller.utils.normalize_module_name(name) - clip = self.clip_acts if norm_name not in self.no_clip_layers else ClipMode.NONE + clip_acts = verify_clip_mode(clip_acts) try: - return wrapper_type(module, qbits_map[name].acts, mode=mode, clip_acts=clip, + return wrapper_type(module, qbits_map[name].acts, mode=mode, clip_acts=clip_acts, activation_stats=self.model_activation_stats.get(norm_name, None), clip_n_stds=clip_n_stds, scale_approx_mult_bits=scale_approx_mult_bits) except NoStatsError: @@ -858,7 +856,6 @@ class PostTrainLinearQuantizer(Quantizer): stats=self.model_activation_stats.get(norm_name, None)) self.clip_acts = clip_acts - self.no_clip_layers = no_clip_layers or [] self.clip_n_stds = clip_n_stds self.model_activation_stats = model_activation_stats or {} self.bits_accum = bits_accum @@ -885,17 +882,21 @@ class PostTrainLinearQuantizer(Quantizer): return distiller.config_component_from_file_by_class(model, args.qe_config_file, 'PostTrainLinearQuantizer') else: + overrides = OrderedDict( + [(layer, OrderedDict([('clip_acts', 'NONE')])) + for layer in args.qe_no_clip_layers] + ) return cls(model, bits_activations=args.qe_bits_acts, bits_parameters=args.qe_bits_wts, bits_accum=args.qe_bits_accum, mode=args.qe_mode, clip_acts=args.qe_clip_acts, - no_clip_layers=args.qe_no_clip_layers, per_channel_wts=args.qe_per_channel, model_activation_stats=args.qe_stats_file, clip_n_stds=args.qe_clip_n_stds, - scale_approx_mult_bits=args.qe_scale_approx_bits) + scale_approx_mult_bits=args.qe_scale_approx_bits, + overrides=overrides) ############################################################################### diff 
diff --git a/examples/quantization/post_train_quant/resnet18_imagenet_post_train.yaml b/examples/quantization/post_train_quant/resnet18_imagenet_post_train.yaml
index 0b6252b..0e07184 100644
--- a/examples/quantization/post_train_quant/resnet18_imagenet_post_train.yaml
+++ b/examples/quantization/post_train_quant/resnet18_imagenet_post_train.yaml
@@ -45,7 +45,6 @@ quantizers:
     model_activation_stats: ../quantization/post_train_quant/stats/resnet18_quant_stats.yaml
     per_channel_wts: True
     clip_acts: AVG
-    no_clip_layers: fc
     overrides:
     # First and last layers + element-wise add layers in 8-bits
       conv1:
@@ -57,3 +56,4 @@ quantizers:
       fc:
         bits_weights: 8
         bits_activations: 8
+        clip_acts: NONE
diff --git a/tests/full_flow_tests.py b/tests/full_flow_tests.py
index 4a27458..01f15c6 100755
--- a/tests/full_flow_tests.py
+++ b/tests/full_flow_tests.py
@@ -116,9 +116,9 @@ TestConfig = namedtuple('TestConfig', ['args', 'dataset', 'checker_fn', 'checker

 test_configs = [
     TestConfig('--arch simplenet_cifar --epochs 2', DS_CIFAR, accuracy_checker, [44.610, 92.080]),
-    TestConfig('-a resnet20_cifar --resume {0} --quantize-eval --evaluate'.
-               format(os.path.join(examples_root, 'ssl', 'checkpoints', 'checkpoint_trained_dense.pth.tar')),
-               DS_CIFAR, accuracy_checker, [91.710, 99.610]),
+    TestConfig('-a resnet20_cifar --resume {0} --quantize-eval --evaluate --qe-clip-acts avg --qe-no-clip-layers {1}'.
+               format(os.path.join(examples_root, 'ssl', 'checkpoints', 'checkpoint_trained_dense.pth.tar'), 'fc'),
+               DS_CIFAR, accuracy_checker, [91.55, 99.63]),
     TestConfig('-a preact_resnet20_cifar --epochs 2 --compress {0}'.
                format(os.path.join('full_flow_tests', 'preact_resnet20_cifar_pact_test.yaml')),
                DS_CIFAR, accuracy_checker, [54.590, 94.810]),
diff --git a/tests/test_post_train_quant.py b/tests/test_post_train_quant.py
index 4522e01..2514c21 100644
--- a/tests/test_post_train_quant.py
+++ b/tests/test_post_train_quant.py
@@ -16,10 +16,13 @@
 import pytest
 import torch
 import torch.testing
+import torch.nn as nn
 from collections import OrderedDict

 from distiller.quantization import RangeLinearQuantParamLayerWrapper, LinearQuantMode, ClipMode, \
-    RangeLinearQuantConcatWrapper, RangeLinearQuantEltwiseMultWrapper, RangeLinearQuantEltwiseAddWrapper
+    RangeLinearQuantConcatWrapper, RangeLinearQuantEltwiseMultWrapper, RangeLinearQuantEltwiseAddWrapper, \
+    PostTrainLinearQuantizer
+from distiller.data_loggers import QuantCalibrationStatsCollector, collector_context
 import distiller.modules


@@ -352,3 +355,55 @@ def test_eltwise_add_layer_wrapper(inputs, eltwise_add_stats, mode, clip_acts, e
     output = model(*inputs)

     torch.testing.assert_allclose(output, expected_output)
+
+
+class DummyRNN(nn.Module):
+    def __init__(self, input_size, hidden_size, num_layers):
+        super(DummyRNN, self).__init__()
+        self.rnn = distiller.modules.DistillerLSTM(input_size, hidden_size, num_layers)
+        self.softmax = nn.Softmax(dim=-1)
+
+    def forward(self, x, h=None):
+        y, h = self.rnn(x, h)
+        y = self.softmax(y)
+        return y, h
+
+
+@pytest.fixture()
+def rnn_model():
+    return DummyRNN(20, 128, 1)
+
+
+@pytest.fixture()
+def rnn_model_stats(rnn_model):
+    collector = QuantCalibrationStatsCollector(rnn_model)
+    dummy_input = torch.randn(35, 50, 20)
+    with collector_context(collector):
+        y,h = rnn_model(dummy_input)
+    return collector.value()
+
+
+@pytest.mark.parametrize(
+    "overrides, e_clip_acts, e_n_stds",
+    [
+        (None, ClipMode.AVG, 0),
+
+        (distiller.utils.yaml_ordered_load("""
+        rnn.cells.0.eltwisemult_hidden:
+            clip_acts: NONE
+        """), ClipMode.NONE, 0),
+
+        (distiller.utils.yaml_ordered_load("""
+        rnn.cells.0.eltwisemult_hidden:
+            clip_acts: N_STD
+            clip_n_stds: 2
+        """), ClipMode.N_STD, 2)
+    ]
+)
+def test_override_no_clip(overrides, e_clip_acts, e_n_stds, rnn_model, rnn_model_stats):
+    quantizer = PostTrainLinearQuantizer(rnn_model, clip_acts="AVG", clip_n_stds=0, overrides=overrides,
+                                         model_activation_stats=rnn_model_stats)
+    quantizer.prepare_model()
+    assert isinstance(quantizer.model.rnn.cells[0].eltwisemult_hidden, RangeLinearQuantEltwiseMultWrapper)
+    assert quantizer.model.rnn.cells[0].eltwisemult_hidden.clip_acts == e_clip_acts
+    assert quantizer.model.rnn.cells[0].eltwisemult_hidden.clip_n_stds == e_n_stds
--
GitLab
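Beyond reproducing the old no-clip behaviour, moving clipping into overrides also lets clip_n_stds be set per layer, which is what the new test above checks. A minimal end-to-end sketch of that flow outside pytest, on a toy feed-forward model (the model and layer names are illustrative assumptions, not part of the patch):

    from collections import OrderedDict

    import torch
    import torch.nn as nn
    import distiller
    from distiller.data_loggers import QuantCalibrationStatsCollector, collector_context
    from distiller.quantization import PostTrainLinearQuantizer

    # Toy model; layer names ('embed', 'classifier') are illustrative.
    model = nn.Sequential(OrderedDict([
        ('embed', nn.Linear(10, 32)),
        ('relu', nn.ReLU()),
        ('classifier', nn.Linear(32, 5)),
    ]))

    # 1. Collect activation stats on calibration data (random data here).
    collector = QuantCalibrationStatsCollector(model)
    with collector_context(collector):
        for _ in range(4):
            model(torch.randn(16, 10))
    stats = collector.value()

    # 2. Quantize with AVG clipping everywhere, except 'classifier', which
    #    gets N_STD clipping with 2 standard deviations via an override.
    overrides = distiller.utils.yaml_ordered_load("""
    classifier:
        clip_acts: N_STD
        clip_n_stds: 2
    """)
    quantizer = PostTrainLinearQuantizer(model, clip_acts='AVG',
                                         overrides=overrides,
                                         model_activation_stats=stats)
    quantizer.prepare_model()
    # 'embed' is now wrapped with AVG clipping; 'classifier' uses N_STD with 2 std devs.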