From ccd11ddbdc08e2ca544f352c66024a8959341830 Mon Sep 17 00:00:00 2001
From: Guy Jacob <guy.jacob@intel.com>
Date: Mon, 17 Feb 2020 12:06:14 +0200
Subject: [PATCH] PyTorch PTQ convert updates/fixes + Raw activations collector

* BUGFIX: Fixed wrong attribute name for zero-point in conversion of eltwise add/mult and concat
* Add PyTorch PTQ convert for embedding (converted to FP32 embedding + quant op)
* Fix conversion function to work with tuple/list model inputs
---
 distiller/data_loggers/collector.py           | 74 +++++++++++++++++--
 .../quantization/pytorch_quant_conversion.py  | 13 +++-
 distiller/quantization/range_linear.py        | 27 ++++++-
 3 files changed, 98 insertions(+), 16 deletions(-)

diff --git a/distiller/data_loggers/collector.py b/distiller/data_loggers/collector.py
index b743072..8cc41aa 100755
--- a/distiller/data_loggers/collector.py
+++ b/distiller/data_loggers/collector.py
@@ -33,16 +33,18 @@ matplotlib.use('Agg')
 import matplotlib.pyplot as plt
 import distiller
 from distiller.quantization.range_linear import is_post_train_quant_wrapper
+from distiller.quantization.pytorch_quant_conversion import QFunctionalWrapper
 import numpy as np
+import concurrent.futures
 
 msglogger = logging.getLogger()
 
-__all__ = ['SummaryActivationStatsCollector', 'RecordsActivationStatsCollector',
-           'QuantCalibrationStatsCollector', 'ActivationHistogramsCollector',
-           'CollectorDirection',
-           'collect_quant_stats', 'collect_histograms',
+__all__ = ['SummaryActivationStatsCollector', 'RecordsActivationStatsCollector', 'QuantCalibrationStatsCollector',
+           'ActivationHistogramsCollector', 'RawActivationsCollector', 'CollectorDirection',
+           'collect_quant_stats', 'collect_histograms', 'collect_raw_outputs',
            'collector_context', 'collectors_context']
 
+
 class CollectorDirection(enum.Enum):
     OUT = 0
     OFM = 0
@@ -169,7 +171,8 @@ class ActivationStatsCollector(object):
         # We make an exception for models that were quantized with 'PostTrainLinearQuantizer'. In these
         # models, the quantized modules are actually wrappers of the original FP32 modules, so they are
         # NOT leaf modules - but we still want to track them.
-        if distiller.has_children(module) and not is_post_train_quant_wrapper(module):
+        if distiller.has_children(module) and not (is_post_train_quant_wrapper(module) or
+                                                   isinstance(module, QFunctionalWrapper)):
             return False
         if isinstance(module, torch.nn.Identity):
             return False
@@ -216,7 +219,7 @@ class SummaryActivationStatsCollector(ActivationStatsCollector):
     inputs_consolidate_func is called on tuple of tensors, and returns a tensor.
     """
     def __init__(self, model, stat_name, summary_fn,
-                 classes=[torch.nn.ReLU, torch.nn.ReLU6, torch.nn.LeakyReLU],
+                 classes=(torch.nn.ReLU, torch.nn.ReLU6, torch.nn.LeakyReLU),
                  collector_direction=CollectorDirection.OUT,
                  inputs_consolidate_func=torch.cat):
         super(SummaryActivationStatsCollector, self).__init__(model, stat_name, classes)
@@ -302,9 +305,9 @@ class RecordsActivationStatsCollector(ActivationStatsCollector):
 
     For obvious reasons, this is slower than SummaryActivationStatsCollector.
     """
-    def __init__(self, model, classes=[torch.nn.ReLU,
+    def __init__(self, model, classes=(torch.nn.ReLU,
                                        torch.nn.ReLU6,
-                                       torch.nn.LeakyReLU]):
+                                       torch.nn.LeakyReLU)):
         super(RecordsActivationStatsCollector, self).__init__(model, "statistics_records", classes)
 
     def _activation_stats_cb(self, module, inputs, output):
@@ -798,6 +801,47 @@ class ActivationHistogramsCollector(ActivationStatsCollector):
         return fname
 
 
+class RawActivationsCollector(ActivationStatsCollector):
+    def __init__(self, model, classes=None):
+        super(RawActivationsCollector, self).__init__(model, "raw_acts", classes)
+
+        _verify_no_dataparallel(model)
+
+    def _activation_stats_cb(self, module, inputs, output):
+        if isinstance(output, torch.Tensor):
+            if output.is_quantized:
+                module.raw_outputs.append(output.dequantize())
+            else:
+                module.raw_outputs.append(output.cpu())
+
+    def _start_counter(self, module):
+        module.raw_outputs = []
+
+    def _reset_counter(self, module):
+        if hasattr(module, 'raw_outputs'):
+            module.raw_outputs = []
+
+    def _collect_activations_stats(self, module, activation_stats, name=''):
+        if not hasattr(module, 'raw_outputs'):
+            return
+
+        if isinstance(module.raw_outputs, list) and len(module.raw_outputs) > 0:
+            module.raw_outputs = torch.stack(module.raw_outputs)
+        activation_stats[module.distiller_name] = module.raw_outputs
+
+    def save(self, dir_name):
+        if not os.path.isdir(dir_name):
+            os.mkdir(dir_name)
+
+        with concurrent.futures.ProcessPoolExecutor() as executor:
+            for idx, (layer_name, raw_outputs) in enumerate(self.value().items()):
+                idx_str = '{:03d}'.format(idx + 1)
+                executor.submit(torch.save, raw_outputs, os.path.join(dir_name,
+                                                                      '-'.join((idx_str, layer_name)) + '.pt'))
+
+        return dir_name
+
+
 def collect_quant_stats(model, test_fn, save_dir=None, classes=None,
                         inplace_runtime_check=False, disable_inplace_attrs=False,
                         inplace_attr_names=('inplace',), modules_to_collect=None):
@@ -893,6 +937,20 @@ def collect_histograms(model, test_fn, save_dir=None, activation_stats=None,
     return histogram_collector.value()
 
 
+def collect_raw_outputs(model, test_fn, save_dir=None, classes=None):
+    msglogger.info('Collecting raw layer outputs for model')
+    collector = RawActivationsCollector(model, classes=classes)
+    with collector_context(collector):
+        test_fn(model=model)
+    msglogger.info('Outputs collection complete')
+    if save_dir is not None:
+        msglogger.info('Saving outputs to disk...')
+        save_path = os.path.join(save_dir, 'raw_outputs')
+        collector.save(save_path)
+        msglogger.info('Outputs saved to ' + save_path)
+    return collector.value()
+
+
 @contextmanager
 def collector_context(collector, modules_list=None):
     """A context manager for an activation collector"""
diff --git a/distiller/quantization/pytorch_quant_conversion.py b/distiller/quantization/pytorch_quant_conversion.py
index bda477a..0b8e5e5 100644
--- a/distiller/quantization/pytorch_quant_conversion.py
+++ b/distiller/quantization/pytorch_quant_conversion.py
@@ -140,7 +140,7 @@ def distiller_quantized_tensor_to_pytorch(tensor: torch.Tensor, scale, zp, num_b
     else:  # dest_dtype == torch.qint32:
         temp_dtype = torch.int32
     tensor = (tensor - zp_diff).to(temp_dtype)
-    if per_channel:
+    if per_channel and scale.shape[channel_dim] > 1:
         return torch._make_per_channel_quantized_tensor(tensor, converted_scale, converted_zp, channel_dim)
     return torch._make_per_tensor_quantized_tensor(tensor, converted_scale, converted_zp)
 
@@ -168,6 +168,8 @@ def _ptq_convert_pass_replace_range_linear_wrappers(module):
                                                              need_reduce_range(qset.quant_mode, torch.quint8))
                     d[idx] = (scale, zp, torch.quint8)
                 new_m = ConditionalQuantizeWrapper(new_m, d)
+        elif isinstance(m, distiller.quantization.RangeLinearEmbeddingWrapper):
+            new_m = m.to_pytorch_quant(need_reduce_range(m.wts_quant_settings.quant_mode, torch.quint8))
         elif distiller.has_children(m):
             new_m = _ptq_convert_pass_replace_range_linear_wrappers(m)
         elif not isinstance(m, nn.Identity):
@@ -246,7 +248,10 @@ def _ptq_convert_pass_remove_redundant_quant_dequant(model, dummy_input):
             handles.append(m.register_forward_pre_hook(quantize_wrapper_check_hook))
         elif isinstance(m, ConditionalDeQuantize):
             handles.append(m.register_forward_pre_hook(dequant_wrapper_check_hook))
-    out = model(dummy_input)
+    if isinstance(dummy_input, torch.Tensor):
+        out = model(dummy_input)
+    else:
+        out = model(*dummy_input)
     for h in handles:
         h.remove()
 
@@ -293,8 +298,8 @@ def convert_distiller_ptq_model_to_pytorch(model, dummy_input, backend='fbgemm')
         raise ValueError('Conversion to PyTorch native quantization supported only for models quantized '
                          'using distiller.quantization.PostTrainLinearQuantizer')
 
-    if dummy_input is None or not isinstance(dummy_input, torch.Tensor):
-        raise ValueError('Valid dummy input tensor required for converting PTQ model to PyTorch')
+    if dummy_input is None:
+        raise ValueError('Valid dummy input required for converting PTQ model to PyTorch')
 
     backends = ('fbgemm', 'qnnpack')
     if backend not in backends:
diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index 31bfc1d..3bb9e40 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -1166,7 +1166,7 @@ class RangeLinearQuantConcatWrapper(RangeLinearQuantWrapper):
                                                        reduce_range)
         m = pytqc.QFunctionalCat(self.wrapped_module.dim)
         m.qfunc.scale = float(scale)
-        m.qfunc.zp = int(zp)
+        m.qfunc.zero_point = int(zp)
         if self.clip_half_range:
             # The scale factor calculated in Distiller already considers the ReLU, so it's OK to apply the
             # ReLU after quantization
@@ -1217,7 +1217,7 @@ class RangeLinearQuantEltwiseAddWrapper(RangeLinearQuantWrapper):
                                                        reduce_range)
         m = pytqc.QFunctionalAddRelu() if self.clip_half_range else pytqc.QFunctionalAdd()
         m.qfunc.scale = float(scale)
-        m.qfunc.zp = int(zp)
+        m.qfunc.zero_point = int(zp)
         return m
 
 
@@ -1269,7 +1269,7 @@ class RangeLinearQuantEltwiseMultWrapper(RangeLinearQuantWrapper):
                                                        reduce_range)
         m = pytqc.QFunctionalMul()
         m.qfunc.scale = float(scale)
-        m.qfunc.zp = int(zp)
+        m.qfunc.zero_point = int(zp)
         if self.clip_half_range:
             # The scale factor calculated in Distiller already considers the ReLU, so it's OK to apply the
             # ReLU after quantization
@@ -1421,6 +1421,24 @@ class RangeLinearEmbeddingWrapper(nn.Module):
         out_f.quant_metadata = self.quant_metadata
         return out_f
 
+    def to_pytorch_quant(self, reduce_range):
+        # No quantized embedding in PyTorch, so use FP32 embedding followed by quantize
+        emb = deepcopy(self.wrapped_module)
+        with torch.no_grad():
+            if self.save_fp_weights:
+                w_dq = nn.Parameter(self.float_weight, requires_grad=False)
+            else:
+                w_dq = nn.Parameter(linear_dequantize(emb.weight, self.w_scale, self.w_zero_point),
+                                    requires_grad=False)
+            emb.weight = w_dq
+
+        scale, zp = pytqc.distiller_qparams_to_pytorch(self.w_scale, self.w_zero_point,
+                                                       self.wts_quant_settings.num_bits,
+                                                       self.wts_quant_settings.quant_mode, torch.quint8,
+                                                       reduce_range)
+
+        return nn.Sequential(emb, nnq.Quantize(scale, zp, torch.quint8))
+
 
 class RangeLinearFakeQuantWrapper(RangeLinearQuantWrapper):
     def __init__(self, wrapped_module, num_bits_acts, mode=LinearQuantMode.SYMMETRIC, clip_acts=ClipMode.NONE,
@@ -1917,7 +1935,8 @@ class PostTrainLinearQuantizer(Quantizer):
         if self.linear_quant_params:
             out['linear_quant_params'] = lqp_dict = OrderedDict()
             for k, v in self.linear_quant_params.items():  # type: str, torch.Tensor
-                lqp_dict[k] = v.item()
+                if v.numel() == 1:
+                    lqp_dict[k] = v.item()
 
         save_path = os.path.join(save_dir, 'layer_quant_params.yaml')
         distiller.yaml_ordered_save(save_path, out)
-- 
GitLab
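
For reference, a minimal usage sketch of two of the user-facing pieces this patch touches: the new collect_raw_outputs() helper and the tuple/list dummy-input support in convert_distiller_ptq_model_to_pytorch(). This is not part of the patch; the model, data loader, save directory and input shape below are illustrative placeholders.

    import torch
    from distiller.data_loggers.collector import collect_raw_outputs
    from distiller.quantization.pytorch_quant_conversion import convert_distiller_ptq_model_to_pytorch

    def example_flow(quantized_model, calibration_loader):
        # `quantized_model` is assumed to have been quantized with
        # distiller.quantization.PostTrainLinearQuantizer; `calibration_loader`
        # is assumed to yield (inputs, targets) batches.
        def eval_fn(model):
            # collect_raw_outputs() invokes test_fn(model=model), so any callable
            # that simply runs forward passes is sufficient here.
            with torch.no_grad():
                for inputs, _ in calibration_loader:
                    model(inputs)

        # Record the raw (de-quantized) output of every tracked layer; with
        # save_dir set, each layer's stacked outputs are written under
        # <save_dir>/raw_outputs/ as <idx>-<layer_name>.pt files.
        raw_acts = collect_raw_outputs(quantized_model, eval_fn, save_dir='logs')

        # After this patch the converter also accepts a tuple/list of dummy
        # inputs (for models whose forward() takes several tensors), not only
        # a single tensor.
        dummy_input = (torch.randn(1, 3, 224, 224),)
        pytorch_model = convert_distiller_ptq_model_to_pytorch(quantized_model, dummy_input,
                                                               backend='fbgemm')
        return raw_acts, pytorch_model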