From 69b1452a9102f138fc655025c9db8e5c99bb96c6 Mon Sep 17 00:00:00 2001 From: Guy Jacob <guy.jacob@intel.com> Date: Mon, 29 Jul 2019 12:09:02 +0300 Subject: [PATCH] Post-train quant: Special handling for bidirectional DistillerLSTM (#337) * For some reason, SummaryGraph generation is broken for DistillerLSTM modules with 'bidirectional' enabled. The ONNX graph optimization stage causes all the nodes from the bidirectional module to vanish from the graph (they're in the graph after the initial trace) * As a temporary workaround to enable stats fusion in post-train quant, if a bidirectional DistillerLSTM is detected, we just do a simple "hard-coded" fusion of the element-wise add op with the subsequent non-linearities and skip the automatic flow with SummaryGraph. --- distiller/quantization/range_linear.py | 67 +++++++++++++++++++------- 1 file changed, 49 insertions(+), 18 deletions(-) diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py index 7e77c57..a07c805 100644 --- a/distiller/quantization/range_linear.py +++ b/distiller/quantization/range_linear.py @@ -935,11 +935,44 @@ class PostTrainLinearQuantizer(Quantizer): msglogger.info('Per-layer quantization parameters saved to ' + save_path) def prepare_model(self, dummy_input=None): - if dummy_input is None: + self.has_bidi_distiller_lstm = any(isinstance(m, distiller.modules.DistillerLSTM) and m.bidirectional for + _, m in self.model.named_modules()) + if self.has_bidi_distiller_lstm: + warnings.warn('Model contains a bidirectional DistillerLSTM module. ' + 'Automatic BN folding and statistics optimization based on tracing is not yet ' + 'supported for models containing such modules.\n' + 'Will perform specific optimization for the DistillerLSTM modules, but any other potential ' + 'opportunities for optimization in the model will be ignored.', UserWarning) + # Setting dummy_input to None to make sure SummaryGraph won't be called + dummy_input = None + elif dummy_input is None: raise ValueError('PostTrainLinearQuantizer requires dummy input in order to perform certain optimizations') super(PostTrainLinearQuantizer, self).prepare_model(dummy_input) def _pre_prepare_model(self, dummy_input): + if not self.has_bidi_distiller_lstm: + self._apply_bn_folding(dummy_input) + self._apply_activation_stats_fusions() + else: + self._apply_bidi_distiller_lstm_stats_fusion() + + if hasattr(msglogger, 'logdir'): + save_path = os.path.join(msglogger.logdir, 'quant_stats_after_prepare_model.yaml') + distiller.yaml_ordered_save(save_path, self.model_activation_stats) + msglogger.info('Updated stats saved to ' + save_path) + + def _clip_stats(self, entry, min_val, max_val): + if entry['max'] < min_val: + entry['min'] = entry['avg_min'] = entry['max'] = entry['avg_max'] = min_val + elif entry['min'] > max_val: + entry['min'] = entry['avg_min'] = entry['max'] = entry['avg_max'] = max_val + else: + entry['min'] = max(min_val, entry['min']) + entry['avg_min'] = max(min_val, entry['avg_min']) + entry['max'] = min(max_val, entry['max']) + entry['avg_max'] = min(max_val, entry['avg_max']) + + def _apply_bn_folding(self, dummy_input): msglogger.info('Applying batch-norm folding ahead of post-training quantization') mt.fold_batch_norms_inference(self.model, adjacency_map=self.adjacency_map) @@ -967,10 +1000,14 @@ class PostTrainLinearQuantizer(Quantizer): except (AttributeError, KeyError): continue + def _apply_activation_stats_fusions(self): # Now we look for certain "fusions" of layers and activations # We modify stats to make sure we quantize only the ranges relevant to the activation function # By doing so we reduce quantization error while still keeping all msglogger.info('Optimizing output statistics for modules followed by ReLU/Tanh/Sigmoid') + + named_modules = OrderedDict(self.model.named_modules()) + model_stats = self.model_activation_stats for n, m in named_modules.items(): if distiller.has_children(m) or n not in self.adjacency_map or len(self.adjacency_map[n].successors) != 1: continue @@ -995,17 +1032,6 @@ class PostTrainLinearQuantizer(Quantizer): succ_type = 'Sigmoid' succ_stats = None - def clip_stats(entry, min_val, max_val): - if entry['max'] < min_val: - entry['min'] = entry['avg_min'] = entry['max'] = entry['avg_max'] = min_val - elif entry['min'] > max_val: - entry['min'] = entry['avg_min'] = entry['max'] = entry['avg_max'] = max_val - else: - entry['min'] = max(min_val, entry['min']) - entry['avg_min'] = max(min_val, entry['avg_min']) - entry['max'] = min(max_val, entry['max']) - entry['avg_max'] = min(max_val, entry['avg_max']) - if succ_type == 'Relu': # ReLU zeros out all negative values, so there's no need to quantize them msglogger.debug(' Module {} followed by Relu, updating stats'.format(n)) @@ -1014,20 +1040,25 @@ class PostTrainLinearQuantizer(Quantizer): succ_stats['inputs'][0] = deepcopy(succ_stats['output']) else: msglogger.debug(" Relu op not a module or post-split, can't update mean and std".format(n)) - clip_stats(m_stats['output'], 0., m_stats['output']['max']) + self._clip_stats(m_stats['output'], 0., m_stats['output']['max']) elif succ_type == 'Sigmoid' or succ_type == 'Tanh': # Tanh / Sigmoid saturate at ~4 / ~6 respectively. No need to quantize their inputs outside # of these ranges msglogger.debug(' Module {} followed by {}, updating stats'.format(n, succ_type)) sat_val = 4. if succ_type == 'Tanh' else 6. - clip_stats(m_stats['output'], -sat_val, sat_val) + self._clip_stats(m_stats['output'], -sat_val, sat_val) if succ_stats is not None: succ_stats['inputs'][0] = deepcopy(m_stats['output']) - if hasattr(msglogger, 'logdir'): - save_path = os.path.join(msglogger.logdir, 'quant_stats_after_prepare_model.yaml') - distiller.yaml_ordered_save(save_path, self.model_activation_stats) - msglogger.info('Updated stats saved to ' + save_path) + def _apply_bidi_distiller_lstm_stats_fusion(self): + distiller_lstm_cells = [n for n, m in self.model.named_modules() if + isinstance(m, distiller.modules.DistillerLSTMCell)] + + for name in distiller_lstm_cells: + name += '.eltwiseadd_gate' + msglogger.debug(' Module {} followed by Sigmoid, updating stats'.format(name)) + sat_val = 6. + self._clip_stats(self.model_activation_stats[name]['output'], -sat_val, sat_val) ############################################################################### -- GitLab