From cc50035ef8d633cd440fbbeea8ff73c86b586124 Mon Sep 17 00:00:00 2001 From: Bar <29775567+barrh@users.noreply.github.com> Date: Wed, 18 Dec 2019 14:00:29 +0200 Subject: [PATCH] IFM sparsity collector (#443) Add directionality to SummaryActivationStatsCollector to allow collection of statistics on incoming and outgoing activations/feature-maps, instead of just outgoing activations. Also includes some code refactoring. --- distiller/apputils/image_classifier.py | 4 +- distiller/data_loggers/collector.py | 59 +++++++++++++++++--------- 2 files changed, 40 insertions(+), 23 deletions(-) diff --git a/distiller/apputils/image_classifier.py b/distiller/apputils/image_classifier.py index c61fbf6..172bfeb 100755 --- a/distiller/apputils/image_classifier.py +++ b/distiller/apputils/image_classifier.py @@ -443,8 +443,8 @@ def create_activation_stats_collectors(model, *phases): return None # note, does *not* set self[key] - we don't want defaultdict's behavior genCollectors = lambda: missingdict({ - "sparsity": SummaryActivationStatsCollector(model, "sparsity", - lambda t: 100 * distiller.utils.sparsity(t)), + "sparsity_ofm": SummaryActivationStatsCollector(model, "sparsity_ofm", + lambda t: 100 * distiller.utils.sparsity(t)), "l1_channels": SummaryActivationStatsCollector(model, "l1_channels", distiller.utils.activation_channels_l1), "apoz_channels": SummaryActivationStatsCollector(model, "apoz_channels", diff --git a/distiller/data_loggers/collector.py b/distiller/data_loggers/collector.py index 27dcafc..b743072 100755 --- a/distiller/data_loggers/collector.py +++ b/distiller/data_loggers/collector.py @@ -14,9 +14,11 @@ # limitations under the License. 
# +import contextlib from functools import partial, reduce import operator import xlsxwriter +import enum import yaml import os from sys import float_info @@ -37,9 +39,17 @@ msglogger = logging.getLogger() __all__ = ['SummaryActivationStatsCollector', 'RecordsActivationStatsCollector', 'QuantCalibrationStatsCollector', 'ActivationHistogramsCollector', + 'CollectorDirection', 'collect_quant_stats', 'collect_histograms', 'collector_context', 'collectors_context'] +class CollectorDirection(enum.Enum): + OUT = 0 + OFM = 0 + IN = 1 + IFM = 1 + IFMS = 1 + class ActivationStatsCollector(object): """Collect model activation statistics information. @@ -133,7 +143,7 @@ class ActivationStatsCollector(object): def save(self, fname): raise NotImplementedError - def _activation_stats_cb(self, module, input, output): + def _activation_stats_cb(self, module, inputs, output): """Handle new activations ('output' argument). This is invoked from the forward-pass callback of module 'module'. @@ -201,20 +211,27 @@ class SummaryActivationStatsCollector(ActivationStatsCollector): This Collector computes the mean of some statistic of the activation. It is rather light-weight and quicker than collecting a record per activation. The statistic function is configured in the constructor. + + collector_direction - enum type: IN for IFMs, OUT for OFM + inputs_consolidate_func is called on tuple of tensors, and returns a tensor. 
""" - def __init__(self, model, stat_name, summary_fn, classes=[torch.nn.ReLU, - torch.nn.ReLU6, - torch.nn.LeakyReLU]): + def __init__(self, model, stat_name, summary_fn, + classes=[torch.nn.ReLU, torch.nn.ReLU6, torch.nn.LeakyReLU], + collector_direction=CollectorDirection.OUT, + inputs_consolidate_func=torch.cat): super(SummaryActivationStatsCollector, self).__init__(model, stat_name, classes) self.summary_fn = summary_fn + self.collector_direction = collector_direction + self.inputs_func = inputs_consolidate_func - def _activation_stats_cb(self, module, input, output): + def _activation_stats_cb(self, module, inputs, output): """Record the activation sparsity of 'module' This is a callback from the forward() of 'module'. """ + feature_map = output if self.collector_direction == CollectorDirection.OUT else self.inputs_func(inputs) try: - getattr(module, self.stat_name).add(self.summary_fn(output.data), output.data.numel()) + getattr(module, self.stat_name).add(self.summary_fn(feature_map.data), feature_map.data.numel()) except RuntimeError as e: if "The expanded size of the tensor" in e.args[0]: raise ValueError("ActivationStatsCollector: a module ({} - {}) was encountered twice during model.apply().\n" @@ -233,11 +250,10 @@ class SummaryActivationStatsCollector(ActivationStatsCollector): setattr(module, self.stat_name, WeightedAverageValueMeter()) # Assign a name to this summary if hasattr(module, 'distiller_name'): - getattr(module, self.stat_name).name = '_'.join((self.stat_name, module.distiller_name)) + getattr(module, self.stat_name).name = module.distiller_name else: - getattr(module, self.stat_name).name = '_'.join((self.stat_name, - module.__class__.__name__, - str(id(module)))) + getattr(module, self.stat_name).name = '_'.join(( + module.__class__.__name__, str(id(module)))) def _reset_counter(self, module): if hasattr(module, self.stat_name): @@ -251,28 +267,29 @@ class SummaryActivationStatsCollector(ActivationStatsCollector): 
activation_stats[getattr(module, self.stat_name).name] = mean def save(self, fname): - """Save the records to an Excel workbook, with one worksheet per layer. - """ - fname = ".".join([fname, 'xlsx']) - try: + """Save the stats to an Excel workbook""" + if not fname.endswith('.xlsx'): + fname = '.'.join([fname, 'xlsx']) + with contextlib.suppress(OSError): os.remove(fname) - except OSError: - pass - records_dict = self.value() - with xlsxwriter.Workbook(fname) as workbook: + def _add_worksheet(workbook, tab_name, record): try: - worksheet = workbook.add_worksheet(self.stat_name) + worksheet = workbook.add_worksheet(tab_name) except xlsxwriter.exceptions.InvalidWorksheetName: worksheet = workbook.add_worksheet() col_names = [] - for col, (module_name, module_summary_data) in enumerate(records_dict.items()): + for col, (module_name, module_summary_data) in enumerate(record.items()): if not isinstance(module_summary_data, list): module_summary_data = [module_summary_data] worksheet.write_column(1, col, module_summary_data) col_names.append(module_name) worksheet.write_row(0, 0, col_names) + + with xlsxwriter.Workbook(fname) as workbook: + _add_worksheet(workbook, self.stat_name, self.value()) + return fname @@ -290,7 +307,7 @@ class RecordsActivationStatsCollector(ActivationStatsCollector): torch.nn.LeakyReLU]): super(RecordsActivationStatsCollector, self).__init__(model, "statistics_records", classes) - def _activation_stats_cb(self, module, input, output): + def _activation_stats_cb(self, module, inputs, output): """Record the activation sparsity of 'module' This is a callback from the forward() of 'module'. -- GitLab