From cc50035ef8d633cd440fbbeea8ff73c86b586124 Mon Sep 17 00:00:00 2001
From: Bar <29775567+barrh@users.noreply.github.com>
Date: Wed, 18 Dec 2019 14:00:29 +0200
Subject: [PATCH] IFM sparsity collector (#443)

Add directionality to SummaryActivationStatsCollector to allow collection of statistics on incoming and outgoing activations/feature-maps, instead of just outgoing activations.

Also includes some code refactoring.
---
 distiller/apputils/image_classifier.py |  4 +-
 distiller/data_loggers/collector.py    | 59 +++++++++++++++++---------
 2 files changed, 40 insertions(+), 23 deletions(-)

diff --git a/distiller/apputils/image_classifier.py b/distiller/apputils/image_classifier.py
index c61fbf6..172bfeb 100755
--- a/distiller/apputils/image_classifier.py
+++ b/distiller/apputils/image_classifier.py
@@ -443,8 +443,8 @@ def create_activation_stats_collectors(model, *phases):
             return None  # note, does *not* set self[key] - we don't want defaultdict's behavior
 
     genCollectors = lambda: missingdict({
-        "sparsity":      SummaryActivationStatsCollector(model, "sparsity",
-                                                         lambda t: 100 * distiller.utils.sparsity(t)),
+        "sparsity_ofm":      SummaryActivationStatsCollector(model, "sparsity_ofm",
+            lambda t: 100 * distiller.utils.sparsity(t)),
         "l1_channels":   SummaryActivationStatsCollector(model, "l1_channels",
                                                          distiller.utils.activation_channels_l1),
         "apoz_channels": SummaryActivationStatsCollector(model, "apoz_channels",
diff --git a/distiller/data_loggers/collector.py b/distiller/data_loggers/collector.py
index 27dcafc..b743072 100755
--- a/distiller/data_loggers/collector.py
+++ b/distiller/data_loggers/collector.py
@@ -14,9 +14,11 @@
 # limitations under the License.
 #
 
+import contextlib
 from functools import partial, reduce
 import operator
 import xlsxwriter
+import enum
 import yaml
 import os
 from sys import float_info
@@ -37,9 +39,17 @@ msglogger = logging.getLogger()
 
 __all__ = ['SummaryActivationStatsCollector', 'RecordsActivationStatsCollector',
            'QuantCalibrationStatsCollector', 'ActivationHistogramsCollector',
+           'CollectorDirection',
            'collect_quant_stats', 'collect_histograms',
            'collector_context', 'collectors_context']
 
+class CollectorDirection(enum.Enum):
+    OUT = 0
+    OFM = 0
+    IN = 1
+    IFM = 1
+    IFMS = 1
+
 
 class ActivationStatsCollector(object):
     """Collect model activation statistics information.
@@ -133,7 +143,7 @@ class ActivationStatsCollector(object):
     def save(self, fname):
         raise NotImplementedError
 
-    def _activation_stats_cb(self, module, input, output):
+    def _activation_stats_cb(self, module, inputs, output):
         """Handle new activations ('output' argument).
 
         This is invoked from the forward-pass callback of module 'module'.
@@ -201,20 +211,27 @@ class SummaryActivationStatsCollector(ActivationStatsCollector):
     This Collector computes the mean of some statistic of the activation.  It is rather
     light-weight and quicker than collecting a record per activation.
     The statistic function is configured in the constructor.
+
+    collector_direction - enum type: IN for IFMs, OUT for OFM
+    inputs_consolidate_func is called on tuple of tensors, and returns a tensor.
     """
-    def __init__(self, model, stat_name, summary_fn, classes=[torch.nn.ReLU,
-                                                              torch.nn.ReLU6,
-                                                              torch.nn.LeakyReLU]):
+    def __init__(self, model, stat_name, summary_fn,
+                 classes=[torch.nn.ReLU, torch.nn.ReLU6, torch.nn.LeakyReLU],
+                 collector_direction=CollectorDirection.OUT,
+                 inputs_consolidate_func=torch.cat):
         super(SummaryActivationStatsCollector, self).__init__(model, stat_name, classes)
         self.summary_fn = summary_fn
+        self.collector_direction = collector_direction
+        self.inputs_func = inputs_consolidate_func
 
-    def _activation_stats_cb(self, module, input, output):
+    def _activation_stats_cb(self, module, inputs, output):
         """Record the activation sparsity of 'module'
 
         This is a callback from the forward() of 'module'.
         """
+        feature_map = output if self.collector_direction == CollectorDirection.OUT else self.inputs_func(inputs)
         try:
-            getattr(module, self.stat_name).add(self.summary_fn(output.data), output.data.numel())
+            getattr(module, self.stat_name).add(self.summary_fn(feature_map.data), feature_map.data.numel())
         except RuntimeError as e:
             if "The expanded size of the tensor" in e.args[0]:
                 raise ValueError("ActivationStatsCollector: a module ({} - {}) was encountered twice during model.apply().\n"
@@ -233,11 +250,10 @@ class SummaryActivationStatsCollector(ActivationStatsCollector):
             setattr(module, self.stat_name, WeightedAverageValueMeter())
             # Assign a name to this summary
             if hasattr(module, 'distiller_name'):
-                getattr(module, self.stat_name).name = '_'.join((self.stat_name, module.distiller_name))
+                getattr(module, self.stat_name).name = module.distiller_name
             else:
-                getattr(module, self.stat_name).name = '_'.join((self.stat_name,
-                                                                 module.__class__.__name__,
-                                                                 str(id(module))))
+                getattr(module, self.stat_name).name = '_'.join((
+                    module.__class__.__name__, str(id(module))))
 
     def _reset_counter(self, module):
         if hasattr(module, self.stat_name):
@@ -251,28 +267,29 @@ class SummaryActivationStatsCollector(ActivationStatsCollector):
             activation_stats[getattr(module, self.stat_name).name] = mean
 
     def save(self, fname):
-        """Save the records to an Excel workbook, with one worksheet per layer.
-        """
-        fname = ".".join([fname, 'xlsx'])
-        try:
+        """Save the stats to an Excel workbook"""
+        if not fname.endswith('.xlsx'):
+            fname = '.'.join([fname, 'xlsx'])
+        with contextlib.suppress(OSError):
             os.remove(fname)
-        except OSError:
-            pass
 
-        records_dict = self.value()
-        with xlsxwriter.Workbook(fname) as workbook:
+        def _add_worksheet(workbook, tab_name, record):
             try:
-                worksheet = workbook.add_worksheet(self.stat_name)
+                worksheet = workbook.add_worksheet(tab_name)
             except xlsxwriter.exceptions.InvalidWorksheetName:
                 worksheet = workbook.add_worksheet()
 
             col_names = []
-            for col, (module_name, module_summary_data) in enumerate(records_dict.items()):
+            for col, (module_name, module_summary_data) in enumerate(record.items()):
                 if not isinstance(module_summary_data, list):
                     module_summary_data = [module_summary_data]
                 worksheet.write_column(1, col, module_summary_data)
                 col_names.append(module_name)
             worksheet.write_row(0, 0, col_names)
+
+        with xlsxwriter.Workbook(fname) as workbook:
+            _add_worksheet(workbook, self.stat_name, self.value())
+
         return fname
 
 
@@ -290,7 +307,7 @@ class RecordsActivationStatsCollector(ActivationStatsCollector):
                                        torch.nn.LeakyReLU]):
         super(RecordsActivationStatsCollector, self).__init__(model, "statistics_records", classes)
 
-    def _activation_stats_cb(self, module, input, output):
+    def _activation_stats_cb(self, module, inputs, output):
         """Record the activation sparsity of 'module'
 
         This is a callback from the forward() of 'module'.
-- 
GitLab