diff --git a/distiller/data_loggers/collector.py b/distiller/data_loggers/collector.py
index 7fef7487a45363ad314567d95413a849fe700e41..a7621d9d3fb33d848a8d4a2ec6a2e52fd9284d41 100755
--- a/distiller/data_loggers/collector.py
+++ b/distiller/data_loggers/collector.py
@@ -30,6 +30,8 @@ import matplotlib
 matplotlib.use('Agg')
 import matplotlib.pyplot as plt
 import distiller
+import numpy as np
+
 msglogger = logging.getLogger()
 
 __all__ = ['SummaryActivationStatsCollector', 'RecordsActivationStatsCollector',
@@ -146,6 +148,30 @@ class ActivationStatsCollector(object):
         raise NotImplementedError
 
 
+class WeightedAverageValueMeter(AverageValueMeter):
+    """
+    A correction to torchnet's AverageValueMeter, which doesn't take the
+    batch size into account when collecting the std.
+    """
+    def add(self, value, n=1):
+        self.sum += value*n
+        if n <= 0:
+            raise ValueError("Cannot use a non-positive weight for the running stat.")
+        elif self.n == 0:
+            self.mean = 0.0 + value  # This is to force a copy in torch/numpy
+            self.std = np.inf
+            self.mean_old = self.mean
+            self.m_s = 0.0
+        else:
+            self.mean = self.mean_old + n * (value - self.mean_old) / float(self.n + n)
+            self.m_s += n*(value - self.mean_old) * (value - self.mean)
+            self.mean_old = self.mean
+            self.std = np.sqrt(self.m_s / (self.n + n - 1.0))
+        self.var = self.std**2
+
+        self.n += n
+
+
 class SummaryActivationStatsCollector(ActivationStatsCollector):
     """This class collects activiations statistical summaries.
@@ -165,7 +191,7 @@ class SummaryActivationStatsCollector(ActivationStatsCollector):
         This is a callback from the forward() of 'module'.
         """
         try:
-            getattr(module, self.stat_name).add(self.summary_fn(output.data))
+            getattr(module, self.stat_name).add(self.summary_fn(output.data), output.data.numel())
         except RuntimeError as e:
             if "The expanded size of the tensor" in e.args[0]:
                 raise ValueError("ActivationStatsCollector: a module ({} - {}) was encountered twice during model.apply().\n"
@@ -181,7 +207,7 @@ class SummaryActivationStatsCollector(ActivationStatsCollector):
 
     def _start_counter(self, module):
         if not hasattr(module, self.stat_name):
-            setattr(module, self.stat_name, AverageValueMeter())
+            setattr(module, self.stat_name, WeightedAverageValueMeter())
             # Assign a name to this summary
             if hasattr(module, 'distiller_name'):
                 getattr(module, self.stat_name).name = '_'.join((self.stat_name, module.distiller_name))
@@ -323,6 +349,7 @@ class _QuantStatsRecord(object):
         for stat_name in ['avg_min', 'avg_max', 'mean', 'std', 'b']:
             records[stat_name] = 0
         records['shape'] = ''
+        records['total_numel'] = 0
         return records
 
     def __init__(self):
@@ -395,7 +422,7 @@ class QuantCalibrationStatsCollector(ActivationStatsCollector):
         self.batch_idx = 0
         self.inplace_runtime_check = inplace_runtime_check
-        self.collecting_laplace = False
+        self.collecting_second_pass = False
 
         if disable_inplace_attrs:
             if not inplace_attr_names:
@@ -425,45 +452,65 @@ class QuantCalibrationStatsCollector(ActivationStatsCollector):
                     'Please collect the required statistics using `collector.start()` and evaluating'
                     ' the model for enough batches.' % name)
 
-    def start_laplace(self):
+    def start_second_pass(self):
         self._check_required_stats()
-        self.collecting_laplace = True
+        self.collecting_second_pass = True
         # reset batch_idx for all leaf modules
         for module in self.model.modules():
             if distiller.has_children(module) or isinstance(module, torch.nn.Identity):
                 continue
             module.batch_idx = 0
+            for record in module.quant_stats.inputs:
+                record['total_numel'] = 0
+            module.quant_stats.output['total_numel'] = 0
 
-    def stop_laplace(self):
-        self.collecting_laplace = False
+    def stop_second_pass(self):
+        self.collecting_second_pass = False
 
     def _activation_stats_cb(self, module, inputs, output):
-        def update_mean(old_mean, new_val):
-            return old_mean + (new_val - old_mean) / module.batch_idx
-
-        def update_std(values, old_std, old_mean, new_mean):
-            # See here:
-            # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
-            numel = values.numel() if isinstance(values, torch.Tensor) else values.size
-            total_values_so_far = numel * (module.batch_idx - 1)
-            M = (old_std ** 2) * (total_values_so_far - 1)
-            mean_diffs = (values - old_mean) * (values - new_mean)
-            M += mean_diffs.sum()
-            return sqrt((M / (total_values_so_far + numel - 1)).item())
-
-        def update_b(values, old_b, mean):
+        """
+        A callback for updating the required statistics for quantization in a module.
+        """
+        def update_running_mean(values, prev_mean, total_values_so_far):
+            """
+            Updates a running mean of a tensor of values
+            Args:
+                values (torch.Tensor): the new tensor
+                prev_mean (float): the previous running mean
+                total_values_so_far (int): the number of values processed so far
+            """
+            curr_numel = values.numel()
+            prev_numel = total_values_so_far
+            return (prev_numel * prev_mean + values.sum().item()) / (prev_numel + curr_numel)
+
+        def update_std(values, prev_std, mean, total_values_so_far):
+            """
+            Updates the running std given the new tensor's values
+            """
+            prev_variance = prev_std ** 2
+            curr_sqr_dists = (values - mean) ** 2
+            new_variance = update_running_mean(curr_sqr_dists, prev_variance, total_values_so_far)
+            return sqrt(new_variance)
+
+        def update_b(values, previous_b, mean, total_values_so_far):
             """
             Updates the 'b' parameter of Laplace Distribution.
             """
-            current_b = (values - mean).abs().mean().item()
-            return old_b + (current_b - old_b) / module.batch_idx
+            curr_abs_dists = (values - mean).abs()
+            return update_running_mean(curr_abs_dists, previous_b, total_values_so_far)
 
         def update_record(record, tensor):
+            if tensor.dtype not in [torch.float16, torch.float32, torch.float64]:
+                # Mean function only works for float tensors
+                tensor = tensor.to(torch.float32)
             if not tensor.is_contiguous():
                 tensor = tensor.contiguous()
             act = tensor.view(tensor.size(0), -1)
-            if self.collecting_laplace:
-                record['b'] = update_b(act, record['b'], record['mean'])
+            numel = act.numel()
+            if self.collecting_second_pass:
+                record['b'] = update_b(act, record['b'], record['mean'], record['total_numel'])
+                record['std'] = update_std(act, record['std'], record['mean'], record['total_numel'])
+                record['total_numel'] += numel
                 return
 
             # In the general case, the average min/max that we're collecting are averages over the per-sample
@@ -472,23 +519,17 @@ class QuantCalibrationStatsCollector(ActivationStatsCollector):
             # But - If each sample contains just a single value, then such a per-sample calculation we'll result in
             # avg_min = avg_max. So in that case we "revert" to calculating "global" values, for the whole batch,
             # instead of per-sample values
-            dim = 0 if act.numel() == act.shape[0] else 1
+            dim = 0 if numel == act.shape[0] else 1
             min_per_sample = act.min(dim=dim)[0]
             max_per_sample = act.max(dim=dim)[0]
             record['min'] = min(record['min'], min_per_sample.min().item())
             record['max'] = max(record['max'], max_per_sample.max().item())
-            try:
-                record['avg_min'] = update_mean(record['avg_min'], min_per_sample.mean().item())
-                record['avg_max'] = update_mean(record['avg_max'], max_per_sample.mean().item())
-                new_mean = update_mean(record['mean'], act.mean().item())
-                record['std'] = update_std(tensor, record['std'], record['mean'], new_mean)
-            except RuntimeError:
-                record['avg_min'] = update_mean(record['avg_min'], min_per_sample.cpu().numpy().mean().item(0))
-                record['avg_max'] = update_mean(record['avg_max'], max_per_sample.cpu().numpy().mean().item(0))
-                new_mean = update_mean(record['mean'], act.cpu().numpy().mean().item(0))
-                record['std'] = update_std(tensor.cpu().numpy(), record['std'], record['mean'], new_mean)
+            record['avg_min'] = update_running_mean(min_per_sample, record['avg_min'], record['total_numel'])
+            record['avg_max'] = update_running_mean(max_per_sample, record['avg_max'], record['total_numel'])
+            new_mean = update_running_mean(act, record['mean'], record['total_numel'])
             record['mean'] = new_mean
+            record['total_numel'] += numel
 
             if not record['shape']:
                 record['shape'] = distiller.size2str(tensor)
@@ -742,13 +783,13 @@ def collect_quant_stats(model, test_fn, save_dir=None, classes=None, inplace_run
                                                         disable_inplace_attrs=disable_inplace_attrs,
                                                         inplace_attr_names=inplace_attr_names)
     with collector_context(quant_stats_collector, modules_to_collect):
-        msglogger.info('Pass 1: Collecting min, max, avg_min, avg_max, mean, std')
+        msglogger.info('Pass 1: Collecting min, max, avg_min, avg_max, mean')
         test_fn(model=model)
         # Collect Laplace distribution stats:
-        msglogger.info('Pass 2: Collecting b parameter')
-        quant_stats_collector.start_laplace()
+        msglogger.info('Pass 2: Collecting b, std parameters')
+        quant_stats_collector.start_second_pass()
         test_fn(model=model)
-        quant_stats_collector.stop_laplace()
+        quant_stats_collector.stop_second_pass()
     msglogger.info('Stats collection complete')
 
     if save_dir is not None:
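
As a sanity check on the math above (not part of the patch): the new update_running_mean/update_std helpers replace the old batch-count-based update_mean with an exact, numel-weighted incremental mean, and compute std in the second pass as a running mean of squared distances from the fixed pass-1 mean. The standalone sketch below re-implements the two helpers with toy data (shapes are made up; float64 avoids accumulation noise) and verifies they reproduce a direct computation over all values:

    # Standalone sketch of the two-pass scheme used in _activation_stats_cb.
    import torch
    from math import sqrt

    def update_running_mean(values, prev_mean, total_values_so_far):
        # Combine the previous mean (over total_values_so_far values) with
        # the sum of the new tensor's values - an exact weighted update.
        curr_numel = values.numel()
        return (total_values_so_far * prev_mean + values.sum().item()) / (total_values_so_far + curr_numel)

    def update_std(values, prev_std, mean, total_values_so_far):
        # With the mean fixed (collected in pass 1), the variance is just a
        # running mean of squared distances from it.
        new_variance = update_running_mean((values - mean) ** 2, prev_std ** 2, total_values_so_far)
        return sqrt(new_variance)

    torch.manual_seed(0)
    batches = [torch.randn(8, 100, dtype=torch.float64) for _ in range(5)]

    mean, seen = 0.0, 0                   # pass 1: running mean
    for b in batches:
        mean = update_running_mean(b, mean, seen)
        seen += b.numel()

    std, seen = 0.0, 0                    # pass 2: std around the pass-1 mean
    for b in batches:
        std = update_std(b, std, mean, seen)
        seen += b.numel()

    flat = torch.cat([b.view(-1) for b in batches])
    assert abs(mean - flat.mean().item()) < 1e-10
    assert abs(std - (flat - mean).pow(2).mean().sqrt().item()) < 1e-10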
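The same weighting fixes SummaryActivationStatsCollector: passing each batch's element count as the meter weight (as _activation_stats_cb now does via output.data.numel()) makes the running mean of per-batch summaries match the mean over all values even when batch sizes differ, which the unweighted AverageValueMeter only guarantees for equal-sized batches. A hypothetical usage sketch, assuming the patched module is importable at the path it modifies:

    import torch
    from distiller.data_loggers.collector import WeightedAverageValueMeter

    batches = [torch.randn(32, 10), torch.randn(7, 10)]  # deliberately unequal sizes
    meter = WeightedAverageValueMeter()
    for b in batches:
        # As in the collector callback: the summary value plus the numel it covers.
        meter.add(b.mean().item(), n=b.numel())

    mean, _ = meter.value()
    true_mean = torch.cat([b.view(-1) for b in batches]).mean().item()
    assert abs(mean - true_mean) < 1e-5  # weighted by n, the two means agree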
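Finally, a hedged end-to-end sketch of the renamed two-pass flow driven by collect_quant_stats (the entry point changed in the last hunk). The toy model, data, and test_fn below are stand-ins, and depending on the distiller version the function may need to be imported from distiller.data_loggers instead:

    import torch
    import torch.nn as nn
    from distiller.data_loggers.collector import collect_quant_stats

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
    data = [torch.randn(4, 3, 32, 32) for _ in range(10)]

    def test_fn(model):
        # Called as test_fn(model=model) by collect_quant_stats.
        model.eval()
        with torch.no_grad():
            for x in data:
                model(x)

    # Runs test_fn twice: pass 1 collects min/max/avg_min/avg_max/mean, then
    # start_second_pass()/stop_second_pass() bracket the std and Laplace 'b'
    # pass, which needs the pass-1 means to already be in place.
    stats = collect_quant_stats(model, test_fn, save_dir=None)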