diff --git a/distiller/data_loggers/collector.py b/distiller/data_loggers/collector.py
index 68fb60787e732db183947a2d7ede993bebef0d2f..1e8d0c3559b7059c0ab96dd39f355ec8af213da7 100755
--- a/distiller/data_loggers/collector.py
+++ b/distiller/data_loggers/collector.py
@@ -475,10 +475,6 @@ class QuantCalibrationStatsCollector(ActivationStatsCollector):
             activation_stats[module.distiller_name]['output'] = module.quant_stats.output
 
     def save(self, fname):
-        def ordered_dict_representer(self, value):
-            return self.represent_mapping('tag:yaml.org,2002:map', value.items())
-        yaml.add_representer(OrderedDict, ordered_dict_representer)
-
         if not fname.endswith('.yaml'):
             fname = ".".join([fname, 'yaml'])
         try:
@@ -487,8 +483,7 @@ class QuantCalibrationStatsCollector(ActivationStatsCollector):
             pass
 
         records_dict = self.value()
-        with open(fname, 'w') as f:
-            yaml.dump(records_dict, f, default_flow_style=False)
+        distiller.yaml_ordered_save(fname, records_dict)
 
         return fname
 
diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index 58fbe59b754fa17dc633cfbdf3f28cf7de84503c..4dd36ed94f324f5ce2d0c11f4a089c69f956e9d2 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -46,7 +46,7 @@ class ClipMode(Enum):
     NONE = 0
     # Clip value calculated by averaging over the max absolute values of samples within a batch
     AVG = 1
-    # Clip value calculated as mean of tesnsor + N standard deviations. N should be specified separately
+    # Clip value calculated as mean of tensor + N standard deviations. N should be specified separately
     N_STD = 2
 
 
@@ -872,6 +872,8 @@
                                                              replace_non_param_layer, RangeLinearQuantEltwiseMultWrapper)
         self.replacement_factory[nn.Embedding] = replace_embedding
 
+        self.save_per_layer_parameters(msglogger.logdir)
+
     @classmethod
     def from_args(cls, model, args):
         """
@@ -898,6 +900,28 @@
                    scale_approx_mult_bits=args.qe_scale_approx_bits,
                    overrides=overrides)
 
+    def save_per_layer_parameters(self, save_dir=''):
+        defaults = OrderedDict(self.model.quantizer_metadata['params'])
+        defaults.pop('bits_activations')
+        defaults.pop('bits_parameters')
+        defaults.pop('bits_accum')
+        out = OrderedDict()
+        for n, m in self.model.named_modules():
+            if distiller.has_children(m):
+                continue
+            qbits = self.module_qbits_map[n]
+            d = OrderedDict()
+            d['bits_activations'] = qbits.acts
+            d['bits_weights'] = qbits.wts
+            d['bits_bias'] = qbits.bias
+            for k, v in defaults.items():
+                actual_v = self.module_overrides_map[n].get(k, v)
+                d[k] = actual_v
+            out[n] = d
+        save_path = os.path.join(save_dir, 'layer_quant_params.yaml')
+        distiller.yaml_ordered_save(save_path, out)
+        msglogger.info('Per-layer quantization parameters saved to ' + save_path)
+
 
 ###############################################################################
 # Quantization-aware training
diff --git a/distiller/utils.py b/distiller/utils.py
index 996d0f5168a28a8cdd96cb0048231ddd428d6519..9532a55feea604eade412a0a7e49c30bc93bbbb7 100755
--- a/distiller/utils.py
+++ b/distiller/utils.py
@@ -686,6 +686,16 @@ def yaml_ordered_load(stream, Loader=yaml.Loader, object_pairs_hook=OrderedDict)
     return yaml.load(stream, OrderedLoader)
 
 
+def yaml_ordered_save(fname, ordered_dict):
+    def ordered_dict_representer(self, value):
+        return self.represent_mapping('tag:yaml.org,2002:map', value.items())
+
+    yaml.add_representer(OrderedDict, ordered_dict_representer)
+
+    with open(fname, 'w') as f:
+        yaml.dump(ordered_dict, f, default_flow_style=False)
+
+
 def float_range_argparse_checker(min_val=0., max_val=1., exc_min=False, exc_max=False):
     def checker(val_str):
         val = float(val_str)
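
Usage sketch (not part of the diff): the new distiller.yaml_ordered_save pairs with the pre-existing distiller.yaml_ordered_load from distiller/utils.py. The file name and stats contents below are made up for illustration.

    from collections import OrderedDict

    import distiller

    # Illustrative stats shaped like what QuantCalibrationStatsCollector.save() writes
    stats = OrderedDict()
    stats['conv1'] = OrderedDict([('min', -1.2), ('max', 3.4)])
    stats['fc'] = OrderedDict([('min', -0.5), ('max', 0.9)])

    # The helper registers an OrderedDict representer that passes items() through,
    # so keys are dumped in insertion order rather than sorted alphabetically
    distiller.yaml_ordered_save('acts_quant_stats.yaml', stats)

    # Round trip: yaml_ordered_load returns OrderedDicts with the same key order
    with open('acts_quant_stats.yaml', 'r') as f:
        loaded = distiller.yaml_ordered_load(f)
    assert list(loaded.keys()) == ['conv1', 'fc']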
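
Similarly, a hedged sketch of consuming the per-layer dump that PostTrainLinearQuantizer now writes to msglogger.logdir as layer_quant_params.yaml; the layer names depend on the model, and the path is assumed to be resolved relative to the log directory.

    import distiller

    # Read back the dump written by save_per_layer_parameters()
    with open('layer_quant_params.yaml', 'r') as f:
        per_layer = distiller.yaml_ordered_load(f)

    # Each leaf module maps to its effective settings: bit-widths taken from
    # module_qbits_map, plus the remaining quantizer-wide defaults, with values
    # from module_overrides_map taking precedence per module.
    for name, params in per_layer.items():
        print(name, params['bits_activations'], params['bits_weights'], params['bits_bias'])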