From ccd11ddbdc08e2ca544f352c66024a8959341830 Mon Sep 17 00:00:00 2001
From: Guy Jacob <guy.jacob@intel.com>
Date: Mon, 17 Feb 2020 12:06:14 +0200
Subject: [PATCH] PyTorch PTQ convert updates/fixes + Raw activations collector

* BUGFIX: Fix wrong attribute name used for the zero-point when
  converting eltwise add/mult and concat wrappers ('zp' -> 'zero_point')
* Add PyTorch PTQ conversion for embedding layers (converted to an
  FP32 embedding followed by a quantize op)
* Fix the conversion function to also accept tuple/list model inputs
  (see usage sketch below)
* Add RawActivationsCollector + collect_raw_outputs() for recording
  and saving raw per-layer outputs (see usage sketch below)
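
Usage sketch for the tuple/list input fix (a minimal sketch; 'quantized_model'
and the input shapes are placeholders, not part of this patch):

    import torch
    from distiller.quantization.pytorch_quant_conversion import \
        convert_distiller_ptq_model_to_pytorch

    # dummy_input may now be a single tensor or a tuple/list of tensors
    # matching the model's forward() signature
    dummy_input = (torch.randn(1, 3, 224, 224), torch.randn(1, 10))
    pyt_model = convert_distiller_ptq_model_to_pytorch(quantized_model,
                                                       dummy_input,
                                                       backend='fbgemm')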
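
Usage sketch for the raw outputs collector (a minimal sketch; 'model' and
'calibration_loader' are placeholders, not part of this patch):

    from distiller.data_loggers.collector import collect_raw_outputs

    # test_fn is invoked as test_fn(model=model), so it must accept a
    # 'model' keyword argument and run forward passes over some data
    def test_fn(model):
        for inputs, _ in calibration_loader:
            model(inputs)

    raw_acts = collect_raw_outputs(model, test_fn, save_dir='logs')
    # in addition to the return value, per-layer outputs are saved as
    # .pt files under <save_dir>/raw_outputs/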
---
 distiller/data_loggers/collector.py           | 74 +++++++++++++++++--
 .../quantization/pytorch_quant_conversion.py  | 13 +++-
 distiller/quantization/range_linear.py        | 27 ++++++-
 3 files changed, 98 insertions(+), 16 deletions(-)

diff --git a/distiller/data_loggers/collector.py b/distiller/data_loggers/collector.py
index b743072..8cc41aa 100755
--- a/distiller/data_loggers/collector.py
+++ b/distiller/data_loggers/collector.py
@@ -33,16 +33,18 @@ matplotlib.use('Agg')
 import matplotlib.pyplot as plt
 import distiller
 from distiller.quantization.range_linear import is_post_train_quant_wrapper
+from distiller.quantization.pytorch_quant_conversion import QFunctionalWrapper
 import numpy as np
+import concurrent.futures
 
 msglogger = logging.getLogger()
 
-__all__ = ['SummaryActivationStatsCollector', 'RecordsActivationStatsCollector',
-           'QuantCalibrationStatsCollector', 'ActivationHistogramsCollector',
-           'CollectorDirection',
-           'collect_quant_stats', 'collect_histograms',
+__all__ = ['SummaryActivationStatsCollector', 'RecordsActivationStatsCollector', 'QuantCalibrationStatsCollector',
+           'ActivationHistogramsCollector', 'RawActivationsCollector', 'CollectorDirection',
+           'collect_quant_stats', 'collect_histograms', 'collect_raw_outputs',
            'collector_context', 'collectors_context']
 
+
 class CollectorDirection(enum.Enum):
     OUT = 0
     OFM = 0
@@ -169,7 +171,8 @@ class ActivationStatsCollector(object):
         # We make an exception for models that were quantized with 'PostTrainLinearQuantizer'. In these
         # models, the quantized modules are actually wrappers of the original FP32 modules, so they are
         # NOT leaf modules - but we still want to track them.
-        if distiller.has_children(module) and not is_post_train_quant_wrapper(module):
+        if distiller.has_children(module) and not (is_post_train_quant_wrapper(module) or
+                                                   isinstance(module, QFunctionalWrapper)):
             return False
         if isinstance(module, torch.nn.Identity):
             return False
@@ -216,7 +219,7 @@ class SummaryActivationStatsCollector(ActivationStatsCollector):
     inputs_consolidate_func is called on tuple of tensors, and returns a tensor.
     """
     def __init__(self, model, stat_name, summary_fn,
-                 classes=[torch.nn.ReLU, torch.nn.ReLU6, torch.nn.LeakyReLU],
+                 classes=(torch.nn.ReLU, torch.nn.ReLU6, torch.nn.LeakyReLU),
                  collector_direction=CollectorDirection.OUT,
                  inputs_consolidate_func=torch.cat):
         super(SummaryActivationStatsCollector, self).__init__(model, stat_name, classes)
@@ -302,9 +305,9 @@ class RecordsActivationStatsCollector(ActivationStatsCollector):
 
     For obvious reasons, this is slower than SummaryActivationStatsCollector.
     """
-    def __init__(self, model, classes=[torch.nn.ReLU,
+    def __init__(self, model, classes=(torch.nn.ReLU,
                                        torch.nn.ReLU6,
-                                       torch.nn.LeakyReLU]):
+                                       torch.nn.LeakyReLU)):
         super(RecordsActivationStatsCollector, self).__init__(model, "statistics_records", classes)
 
     def _activation_stats_cb(self, module, inputs, output):
@@ -798,6 +801,47 @@ class ActivationHistogramsCollector(ActivationStatsCollector):
         return fname
 
 
+class RawActivationsCollector(ActivationStatsCollector):
+    def __init__(self, model, classes=None):
+        super(RawActivationsCollector, self).__init__(model, "raw_acts", classes)
+
+        _verify_no_dataparallel(model)
+
+    def _activation_stats_cb(self, module, inputs, output):
+        if isinstance(output, torch.Tensor):
+            if output.is_quantized:
+                module.raw_outputs.append(output.dequantize())
+            else:
+                module.raw_outputs.append(output.cpu())
+
+    def _start_counter(self, module):
+        module.raw_outputs = []
+
+    def _reset_counter(self, module):
+        if hasattr(module, 'raw_outputs'):
+            module.raw_outputs = []
+
+    def _collect_activations_stats(self, module, activation_stats, name=''):
+        if not hasattr(module, 'raw_outputs'):
+            return
+
+        if isinstance(module.raw_outputs, list) and len(module.raw_outputs) > 0:
+            module.raw_outputs = torch.stack(module.raw_outputs)
+        activation_stats[module.distiller_name] = module.raw_outputs
+
+    def save(self, dir_name):
+        if not os.path.isdir(dir_name):
+            os.mkdir(dir_name)
+
+        with concurrent.futures.ProcessPoolExecutor() as executor:
+            for idx, (layer_name, raw_outputs) in enumerate(self.value().items()):
+                idx_str = '{:03d}'.format(idx + 1)
+                executor.submit(torch.save, raw_outputs, os.path.join(dir_name,
+                                                                      '-'.join((idx_str, layer_name)) + '.pt'))
+
+        return dir_name
+
+
 def collect_quant_stats(model, test_fn, save_dir=None, classes=None, inplace_runtime_check=False,
                         disable_inplace_attrs=False, inplace_attr_names=('inplace',),
                         modules_to_collect=None):
@@ -893,6 +937,20 @@ def collect_histograms(model, test_fn, save_dir=None, activation_stats=None,
     return histogram_collector.value()
 
 
+def collect_raw_outputs(model, test_fn, save_dir=None, classes=None):
+    msglogger.info('Collecting raw layer outputs for model')
+    collector = RawActivationsCollector(model, classes=classes)
+    with collector_context(collector):
+        test_fn(model=model)
+    msglogger.info('Outputs collection complete')
+    if save_dir is not None:
+        msglogger.info('Saving outputs to disk...')
+        save_path = os.path.join(save_dir, 'raw_outputs')
+        collector.save(save_path)
+        msglogger.info('Outputs saved to ' + save_path)
+    return collector.value()
+
+
 @contextmanager
 def collector_context(collector, modules_list=None):
     """A context manager for an activation collector"""
diff --git a/distiller/quantization/pytorch_quant_conversion.py b/distiller/quantization/pytorch_quant_conversion.py
index bda477a..0b8e5e5 100644
--- a/distiller/quantization/pytorch_quant_conversion.py
+++ b/distiller/quantization/pytorch_quant_conversion.py
@@ -140,7 +140,7 @@ def distiller_quantized_tensor_to_pytorch(tensor: torch.Tensor, scale, zp, num_b
     else:  # dest_dtype == torch.qint32:
         temp_dtype = torch.int32
     tensor = (tensor - zp_diff).to(temp_dtype)
-    if per_channel:
+    if per_channel and scale.shape[channel_dim] > 1:
         return torch._make_per_channel_quantized_tensor(tensor, converted_scale, converted_zp, channel_dim)
     return torch._make_per_tensor_quantized_tensor(tensor, converted_scale, converted_zp)
 
@@ -168,6 +168,8 @@ def _ptq_convert_pass_replace_range_linear_wrappers(module):
                                                              need_reduce_range(qset.quant_mode, torch.quint8))
                     d[idx] = (scale, zp, torch.quint8)
                 new_m = ConditionalQuantizeWrapper(new_m, d)
+        elif isinstance(m, distiller.quantization.RangeLinearEmbeddingWrapper):
+            new_m = m.to_pytorch_quant(need_reduce_range(m.wts_quant_settings.quant_mode, torch.quint8))
         elif distiller.has_children(m):
             new_m = _ptq_convert_pass_replace_range_linear_wrappers(m)
         elif not isinstance(m, nn.Identity):
@@ -246,7 +248,10 @@ def _ptq_convert_pass_remove_redundant_quant_dequant(model, dummy_input):
             handles.append(m.register_forward_pre_hook(quantize_wrapper_check_hook))
         elif isinstance(m, ConditionalDeQuantize):
             handles.append(m.register_forward_pre_hook(dequant_wrapper_check_hook))
-    out = model(dummy_input)
+    if isinstance(dummy_input, torch.Tensor):
+        out = model(dummy_input)
+    else:
+        out = model(*dummy_input)
     for h in handles:
         h.remove()
 
@@ -293,8 +298,8 @@ def convert_distiller_ptq_model_to_pytorch(model, dummy_input, backend='fbgemm')
         raise ValueError('Conversion to PyTorch native quantization supported only for models quantized '
                          'using distiller.quantization.PostTrainLinearQuantizer')
 
-    if dummy_input is None or not isinstance(dummy_input, torch.Tensor):
-        raise ValueError('Valid dummy input tensor required for converting PTQ model to PyTorch')
+    if dummy_input is None:
+        raise ValueError('Valid dummy input required for converting PTQ model to PyTorch')
 
     backends = ('fbgemm', 'qnnpack')
     if backend not in backends:
diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index 31bfc1d..3bb9e40 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -1166,7 +1166,7 @@ class RangeLinearQuantConcatWrapper(RangeLinearQuantWrapper):
                                                        reduce_range)
         m = pytqc.QFunctionalCat(self.wrapped_module.dim)
         m.qfunc.scale = float(scale)
-        m.qfunc.zp = int(zp)
+        m.qfunc.zero_point = int(zp)
         if self.clip_half_range:
             # The scale factor calculated in Distiller already considers the ReLU, so it's OK to apply the
             # ReLU after quantization
@@ -1217,7 +1217,7 @@ class RangeLinearQuantEltwiseAddWrapper(RangeLinearQuantWrapper):
                                                        reduce_range)
         m = pytqc.QFunctionalAddRelu() if self.clip_half_range else pytqc.QFunctionalAdd()
         m.qfunc.scale = float(scale)
-        m.qfunc.zp = int(zp)
+        m.qfunc.zero_point = int(zp)
         return m
 
 
@@ -1269,7 +1269,7 @@ class RangeLinearQuantEltwiseMultWrapper(RangeLinearQuantWrapper):
                                                        reduce_range)
         m = pytqc.QFunctionalMul()
         m.qfunc.scale = float(scale)
-        m.qfunc.zp = int(zp)
+        m.qfunc.zero_point = int(zp)
         if self.clip_half_range:
             # The scale factor calculated in Distiller already considers the ReLU, so it's OK to apply the
             # ReLU after quantization
@@ -1421,6 +1421,24 @@ class RangeLinearEmbeddingWrapper(nn.Module):
         out_f.quant_metadata = self.quant_metadata
         return out_f
 
+    def to_pytorch_quant(self, reduce_range):
+        # No quantized embedding in PyTorch, so use FP32 embedding followed by quantize
+        emb = deepcopy(self.wrapped_module)
+        with torch.no_grad():
+            if self.save_fp_weights:
+                w_dq = nn.Parameter(self.float_weight, requires_grad=False)
+            else:
+                w_dq = nn.Parameter(linear_dequantize(emb.weight, self.w_scale, self.w_zero_point),
+                                    requires_grad=False)
+        emb.weight = w_dq
+
+        scale, zp = pytqc.distiller_qparams_to_pytorch(self.w_scale, self.w_zero_point,
+                                                       self.wts_quant_settings.num_bits,
+                                                       self.wts_quant_settings.quant_mode, torch.quint8,
+                                                       reduce_range)
+
+        return nn.Sequential(emb, nnq.Quantize(scale, zp, torch.quint8))
+
 
 class RangeLinearFakeQuantWrapper(RangeLinearQuantWrapper):
     def __init__(self, wrapped_module, num_bits_acts, mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE,
@@ -1917,7 +1935,8 @@ class PostTrainLinearQuantizer(Quantizer):
         if self.linear_quant_params:
             out['linear_quant_params'] = lqp_dict = OrderedDict()
             for k, v in self.linear_quant_params.items():  # type: str, torch.Tensor
-                lqp_dict[k] = v.item()
+                if v.numel() == 1:
+                    lqp_dict[k] = v.item()
 
         save_path = os.path.join(save_dir, 'layer_quant_params.yaml')
         distiller.yaml_ordered_save(save_path, out)
-- 
GitLab