diff --git a/distiller/data_loggers/collector.py b/distiller/data_loggers/collector.py
index d3f98cb79eaefb0d7f13434342d6a68741155cf1..f561a2ba19288730a6cb98eb61e2172d4a145591 100755
--- a/distiller/data_loggers/collector.py
+++ b/distiller/data_loggers/collector.py
@@ -30,6 +30,7 @@ import matplotlib
 matplotlib.use('Agg')
 import matplotlib.pyplot as plt
 import distiller
+from distiller.quantization.range_linear import is_post_train_quant_wrapper
 import numpy as np

 msglogger = logging.getLogger()
@@ -78,6 +79,13 @@ class ActivationStatsCollector(object):
         # a unique, human-readable name per layer.
         distiller.utils.assign_layer_fq_names(model)

+        # Currently this is internal, and its only purpose is to enable skipping collection
+        # for wrapped modules inside post-training quantization wrapper classes.
+        # When doing PTQ, the outputs of these wrapped modules are actually intermediate results
+        # which are not relevant for tracking.
+        self._dont_collect_list = [module.wrapped_module.distiller_name for module in model.modules() if
+                                   is_post_train_quant_wrapper(module)]
+
     def value(self):
         """Return a dictionary containing {layer_name: statistic}"""
         activation_stats = OrderedDict()
@@ -104,10 +112,7 @@ class ActivationStatsCollector(object):

         Eligible modules are currently filtered by their class type.
         """
-        if distiller.has_children(module) or isinstance(module, torch.nn.Identity):
-            return
-        register_all_class_types = not self.classes
-        if register_all_class_types or isinstance(module, tuple(self.classes)):
+        if self._should_collect(module):
             self.fwd_hook_handles.append(module.register_forward_hook(self._activation_stats_cb))
             self._start_counter(module)

@@ -147,6 +152,24 @@ class ActivationStatsCollector(object):
         """Handle new activations - this is subclass-specific code"""
         raise NotImplementedError

+    def _should_collect(self, module):
+        if module.distiller_name in self._dont_collect_list:
+            return False
+        # In general, we only collect stats for "leaf" modules.
+        # We make an exception for models that were quantized with 'PostTrainLinearQuantizer'. In these
+        # models, the quantized modules are actually wrappers of the original FP32 modules, so they are
+        # NOT leaf modules - but we still want to track them.
+        if distiller.has_children(module) and not is_post_train_quant_wrapper(module):
+            return False
+        if isinstance(module, torch.nn.Identity):
+            return False
+
+        register_all_class_types = not self.classes
+        if register_all_class_types or isinstance(module, tuple(self.classes)):
+            return True
+
+        return False
+

 class WeightedAverageValueMeter(AverageValueMeter):
     """
@@ -173,7 +196,7 @@ class WeightedAverageValueMeter(AverageValueMeter):


 class SummaryActivationStatsCollector(ActivationStatsCollector):
-    """This class collects activiations statistical summaries.
+    """This class collects activations statistical summaries.

     This Collector computes the mean of some statistic of the activation.
     It is rather light-weight and quicker than collecting a record per activation.
@@ -445,7 +468,7 @@ class QuantCalibrationStatsCollector(ActivationStatsCollector):
         Check whether the required statistics were collected to allow collecting laplace distribution stats.
         """
         for name, module in self.model.named_modules():
-            if distiller.has_children(module) or isinstance(module, torch.nn.Identity):
+            if not self._should_collect(module):
                 continue
             if not hasattr(module, 'quant_stats'):
                 raise RuntimeError('Collection of Laplace distribution statistics is '
@@ -463,9 +486,9 @@ class QuantCalibrationStatsCollector(ActivationStatsCollector):
     def start_second_pass(self):
         self._check_required_stats()
         self.collecting_second_pass = True
-        # reset batch_idx for all leaf modules
+        # reset batch_idx for all tracked modules
         for module in self.model.modules():
-            if distiller.has_children(module) or isinstance(module, torch.nn.Identity):
+            if not self._should_collect(module):
                 continue
             module.batch_idx = 0
             for record in module.quant_stats.inputs:
@@ -571,8 +594,6 @@ class QuantCalibrationStatsCollector(ActivationStatsCollector):
         module.batch_idx = 0

     def _collect_activations_stats(self, module, activation_stats, name=''):
-        if distiller.utils.has_children(module):
-            return
         if not hasattr(module, 'quant_stats'):
             return

@@ -688,8 +709,6 @@ class ActivationHistogramsCollector(ActivationStatsCollector):
             self._reset(module)

     def _collect_activations_stats(self, module, activation_stats, name=''):
-        if distiller.utils.has_children(module):
-            return
         if not hasattr(module, 'output_hist'):
             return

diff --git a/distiller/quantization/range_linear.py b/distiller/quantization/range_linear.py
index a5e03b4b3f131c18aa210cd0f5418e8e4297e068..87f85466a54ef1f67b1f435edaf8c62e57d3384b 100644
--- a/distiller/quantization/range_linear.py
+++ b/distiller/quantization/range_linear.py
@@ -674,7 +674,7 @@ class RangeLinearQuantParamLayerWrapper(RangeLinearQuantWrapper):
                 self.is_simulated_quant_weight_shifted.fill_(True)  # i.e. is_simulated_quant_weight_shifted = True
             input_q += input_q.quant_metadata.zero_point

-        accum = self.wrapped_module.forward(input_q)
+        accum = self.wrapped_module(input_q)
         clamp(accum.data, self.accum_min_q_val, self.accum_max_q_val, inplace=True)

         return accum
@@ -749,8 +749,8 @@ class RangeLinearQuantMatmulWrapper(RangeLinearQuantWrapper):

     def quantized_forward(self, input0_q, input1_q):
         self.accum_scale = input0_q.quant_metadata.scale * input1_q.quant_metadata.scale
-        accum = self.wrapped_module.forward(input0_q + input0_q.quant_metadata.zero_point,
-                                            input1_q + input1_q.quant_metadata.zero_point)
+        accum = self.wrapped_module(input0_q + input0_q.quant_metadata.zero_point,
+                                    input1_q + input1_q.quant_metadata.zero_point)
         clamp(accum.data, self.accum_min_q_val, self.accum_max_q_val, inplace=True)
         return accum

@@ -994,8 +994,13 @@ class RangeLinearFakeQuantWrapper(RangeLinearQuantWrapper):
         return output_scale, output_zero_point


-def _is_range_linear_wrapper(module):
-    return isinstance(module, (RangeLinearEmbeddingWrapper, RangeLinearQuantWrapper))
+_ptq_wrappers_int_only = (RangeLinearQuantWrapper, RangeLinearEmbeddingWrapper)
+_ptq_wrappers_all = _ptq_wrappers_int_only + (FPWrapper,)
+
+
+def is_post_train_quant_wrapper(module, include_fpwrapper=True):
+    types = _ptq_wrappers_all if include_fpwrapper else _ptq_wrappers_int_only
+    return isinstance(module, types)


 class PostTrainLinearQuantizer(Quantizer):
@@ -1232,7 +1237,7 @@ class PostTrainLinearQuantizer(Quantizer):

     def named_acts_quant_params(self):
         for module_name, module in self.model.named_modules():
-            if _is_range_linear_wrapper(module):
+            if is_post_train_quant_wrapper(module, include_fpwrapper=False):
                 for buff_name, buff in module.named_acts_quant_params():
                     full_buff_name = "%s.%s" % (module_name, buff_name)
                     yield full_buff_name, buff
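
Note on the collection rule introduced in collector.py: modules wrapped by a post-training quantization wrapper are skipped (their outputs are intermediate results of the simulated-quantization math), while the wrapper itself is tracked even though it is not a leaf module. Below is a minimal, self-contained sketch of that rule in plain PyTorch, with no Distiller imports; FakePTQWrapper, has_children, build_dont_collect_list and should_collect are hypothetical stand-ins for the real wrapper classes and for the new _dont_collect_list / _should_collect logic, not Distiller's API.

import torch.nn as nn


class FakePTQWrapper(nn.Module):
    """Hypothetical stand-in for a post-training quantization wrapper
    (e.g. RangeLinearQuantWrapper): it owns the original FP32 module."""
    def __init__(self, wrapped_module):
        super().__init__()
        self.wrapped_module = wrapped_module

    def forward(self, x):
        return self.wrapped_module(x)


def has_children(module):
    # Stand-in for distiller.has_children
    return any(True for _ in module.children())


def build_dont_collect_list(model):
    # Fully-qualified names of modules wrapped by a PTQ wrapper; their outputs
    # are intermediate results, so stats should not be collected for them.
    return ['{}.wrapped_module'.format(name)
            for name, m in model.named_modules() if isinstance(m, FakePTQWrapper)]


def should_collect(name, module, dont_collect):
    if name in dont_collect:
        return False
    # Only leaf modules are tracked, except PTQ wrappers, which have children
    # (the wrapped FP32 module) but still produce the "real" layer output.
    if has_children(module) and not isinstance(module, FakePTQWrapper):
        return False
    if isinstance(module, nn.Identity):
        return False
    return True


if __name__ == '__main__':
    model = nn.Sequential(FakePTQWrapper(nn.Linear(8, 4)), nn.ReLU())
    skip = build_dont_collect_list(model)
    for name, m in model.named_modules():
        print(name or '<root>', should_collect(name, m, skip))
    # Expected: '<root>' and '0.wrapped_module' are skipped,
    # '0' (the wrapper) and '1' (the ReLU) are collected.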