diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index 7e77c5778906a2a9f514033ede3932a444d153a4..a07c805feeee1a4a6644417d4a7ff45b31537303 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -935,11 +935,44 @@ class PostTrainLinearQuantizer(Quantizer):
             msglogger.info('Per-layer quantization parameters saved to ' + save_path)
 
     def prepare_model(self, dummy_input=None):
-        if dummy_input is None:
+        self.has_bidi_distiller_lstm = any(isinstance(m, distiller.modules.DistillerLSTM) and m.bidirectional for
+                                           _, m in self.model.named_modules())
+        if self.has_bidi_distiller_lstm:
+            warnings.warn('Model contains a bidirectional DistillerLSTM module. '
+                          'Automatic BN folding and statistics optimization based on tracing is not yet '
+                          'supported for models containing such modules.\n'
+                          'Will perform specific optimization for the DistillerLSTM modules, but any other potential '
+                          'opportunities for optimization in the model will be ignored.', UserWarning)
+            # Setting dummy_input to None to make sure SummaryGraph won't be called
+            dummy_input = None
+        elif dummy_input is None:
             raise ValueError('PostTrainLinearQuantizer requires dummy input in order to perform certain optimizations')
         super(PostTrainLinearQuantizer, self).prepare_model(dummy_input)
 
     def _pre_prepare_model(self, dummy_input):
+        if not self.has_bidi_distiller_lstm:
+            self._apply_bn_folding(dummy_input)
+            self._apply_activation_stats_fusions()
+        else:
+            self._apply_bidi_distiller_lstm_stats_fusion()
+
+        if hasattr(msglogger, 'logdir'):
+            save_path = os.path.join(msglogger.logdir, 'quant_stats_after_prepare_model.yaml')
+            distiller.yaml_ordered_save(save_path, self.model_activation_stats)
+            msglogger.info('Updated stats saved to ' + save_path)
+
+    def _clip_stats(self, entry, min_val, max_val):
+        if entry['max'] < min_val:
+            entry['min'] = entry['avg_min'] = entry['max'] = entry['avg_max'] = min_val
+        elif entry['min'] > max_val:
+            entry['min'] = entry['avg_min'] = entry['max'] = entry['avg_max'] = max_val
+        else:
+            entry['min'] = max(min_val, entry['min'])
+            entry['avg_min'] = max(min_val, entry['avg_min'])
+            entry['max'] = min(max_val, entry['max'])
+            entry['avg_max'] = min(max_val, entry['avg_max'])
+
+    def _apply_bn_folding(self, dummy_input):
         msglogger.info('Applying batch-norm folding ahead of post-training quantization')
         mt.fold_batch_norms_inference(self.model, adjacency_map=self.adjacency_map)
 
@@ -967,10 +1000,14 @@ class PostTrainLinearQuantizer(Quantizer):
             except (AttributeError, KeyError):
                 continue
 
+    def _apply_activation_stats_fusions(self):
         # Now we look for certain "fusions" of layers and activations
         # We modify stats to make sure we quantize only the ranges relevant to the activation function
         # By doing so we reduce quantization error while still keeping all
         msglogger.info('Optimizing output statistics for modules followed by ReLU/Tanh/Sigmoid')
+
+        named_modules = OrderedDict(self.model.named_modules())
+        model_stats = self.model_activation_stats
         for n, m in named_modules.items():
             if distiller.has_children(m) or n not in self.adjacency_map or len(self.adjacency_map[n].successors) != 1:
                 continue
@@ -995,17 +1032,6 @@ class PostTrainLinearQuantizer(Quantizer):
                     succ_type = 'Sigmoid'
                 succ_stats = None
 
-            def clip_stats(entry, min_val, max_val):
-                if entry['max'] < min_val:
-                    entry['min'] = entry['avg_min'] = entry['max'] = entry['avg_max'] = min_val
-                elif entry['min'] > max_val:
-                    entry['min'] = entry['avg_min'] = entry['max'] = entry['avg_max'] = max_val
-                else:
-                    entry['min'] = max(min_val, entry['min'])
-                    entry['avg_min'] = max(min_val, entry['avg_min'])
-                    entry['max'] = min(max_val, entry['max'])
-                    entry['avg_max'] = min(max_val, entry['avg_max'])
-
             if succ_type == 'Relu':
                 # ReLU zeros out all negative values, so there's no need to quantize them
                 msglogger.debug('  Module {} followed by Relu, updating stats'.format(n))
@@ -1014,20 +1040,25 @@ class PostTrainLinearQuantizer(Quantizer):
                     succ_stats['inputs'][0] = deepcopy(succ_stats['output'])
                 else:
                     msglogger.debug("    Relu op not a module or post-split, can't update mean and std".format(n))
-                clip_stats(m_stats['output'], 0., m_stats['output']['max'])
+                self._clip_stats(m_stats['output'], 0., m_stats['output']['max'])
             elif succ_type == 'Sigmoid' or succ_type == 'Tanh':
                 # Tanh / Sigmoid saturate at ~4 / ~6 respectively. No need to quantize their inputs outside
                 # of these ranges
                 msglogger.debug('  Module {} followed by {}, updating stats'.format(n, succ_type))
                 sat_val = 4. if succ_type == 'Tanh' else 6.
-                clip_stats(m_stats['output'], -sat_val, sat_val)
+                self._clip_stats(m_stats['output'], -sat_val, sat_val)
                 if succ_stats is not None:
                     succ_stats['inputs'][0] = deepcopy(m_stats['output'])
 
-        if hasattr(msglogger, 'logdir'):
-            save_path = os.path.join(msglogger.logdir, 'quant_stats_after_prepare_model.yaml')
-            distiller.yaml_ordered_save(save_path, self.model_activation_stats)
-            msglogger.info('Updated stats saved to ' + save_path)
+    def _apply_bidi_distiller_lstm_stats_fusion(self):
+        distiller_lstm_cells = [n for n, m in self.model.named_modules() if
+                                isinstance(m, distiller.modules.DistillerLSTMCell)]
+
+        for name in distiller_lstm_cells:
+            name += '.eltwiseadd_gate'
+            msglogger.debug('  Module {} followed by Sigmoid, updating stats'.format(name))
+            sat_val = 6.
+            self._clip_stats(self.model_activation_stats[name]['output'], -sat_val, sat_val)
 
 
 ###############################################################################
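
Note (reviewer sketch, not part of the patch): the clipping rule that the diff factors out into _clip_stats can be exercised standalone. The snippet below assumes a stats entry shaped like Distiller's activation-stats dicts (keys 'min', 'avg_min', 'max', 'avg_max'); the function and variable names here are illustrative only.

# Standalone illustration of the stats-clipping applied ahead of saturating
# activations (mirrors _clip_stats in the diff above).
def clip_stats(entry, min_val, max_val):
    """Clamp recorded activation statistics to the window [min_val, max_val]."""
    if entry['max'] < min_val:
        # Entire observed range lies below the window: collapse stats to min_val
        entry['min'] = entry['avg_min'] = entry['max'] = entry['avg_max'] = min_val
    elif entry['min'] > max_val:
        # Entire observed range lies above the window: collapse stats to max_val
        entry['min'] = entry['avg_min'] = entry['max'] = entry['avg_max'] = max_val
    else:
        # Otherwise clamp each statistic into the window
        entry['min'] = max(min_val, entry['min'])
        entry['avg_min'] = max(min_val, entry['avg_min'])
        entry['max'] = min(max_val, entry['max'])
        entry['avg_max'] = min(max_val, entry['avg_max'])

# Example: a layer feeding a Sigmoid saturates at ~6, so quantizing its output
# outside [-6, 6] only wastes range.
stats = {'min': -14.2, 'avg_min': -11.7, 'max': 9.3, 'avg_max': 8.1}
clip_stats(stats, -6., 6.)
print(stats)  # {'min': -6.0, 'avg_min': -6.0, 'max': 6.0, 'avg_max': 6.0}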