From ab8d2960fad83b27d617745f7dd3b84dac819977 Mon Sep 17 00:00:00 2001
From: Bar <elhararb@gmail.com>
Date: Tue, 14 May 2019 16:38:19 +0300
Subject: [PATCH] Improved logging when saving collectors data (#251)

---
 distiller/data_loggers/collector.py                    |  6 +++++-
 examples/classifier_compression/compress_classifier.py | 10 ++++++----
 2 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/distiller/data_loggers/collector.py b/distiller/data_loggers/collector.py
index 9ceed23..9b0350b 100755
--- a/distiller/data_loggers/collector.py
+++ b/distiller/data_loggers/collector.py
@@ -113,7 +113,7 @@ class ActivationStatsCollector(object):
         return self
 
     def save(self, fname):
-        pass
+        raise NotImplementedError
 
     def _activation_stats_cb(self, module, input, output):
         """Handle new activations ('output' argument).
@@ -207,6 +207,7 @@ class SummaryActivationStatsCollector(ActivationStatsCollector):
                 worksheet.write_column(1, col, module_summary_data)
                 col_names.append(module_name)
             worksheet.write_row(0, 0, col_names)
+        return fname
 
 
 class RecordsActivationStatsCollector(ActivationStatsCollector):
@@ -283,6 +284,7 @@ class RecordsActivationStatsCollector(ActivationStatsCollector):
                     col_names.append(col_name)
                 worksheet.write_row(0, 0, col_names)
                 worksheet.write(0, len(col_names)+2, module_act_records['shape'])
+        return fname
 
     def _start_counter(self, module):
         if not hasattr(module, "statistics_records"):
@@ -451,6 +453,8 @@ class QuantCalibrationStatsCollector(ActivationStatsCollector):
         with open(fname, 'w') as f:
             yaml.dump(records_dict, f, default_flow_style=False)
 
+        return fname
+
 
 @contextmanager
 def collector_context(collector):
diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py
index ad499a7..1582f87 100755
--- a/examples/classifier_compression/compress_classifier.py
+++ b/examples/classifier_compression/compress_classifier.py
@@ -755,12 +755,14 @@ def create_quantization_stats_collector(model):
 
 
 def save_collectors_data(collectors, directory):
-    """Utility function that saves all activation statistics to Excel workbooks
+    """Utility function that saves all activation statistics to disk.
+
+    File type and format of contents are collector-specific.
     """
     for name, collector in collectors.items():
-        workbook = os.path.join(directory, name)
-        msglogger.info("Generating {}".format(workbook))
-        collector.save(workbook)
+        msglogger.info('Saving data for collector {}...'.format(name))
+        file_path = collector.save(os.path.join(directory, name))
+        msglogger.info("Saved to {}".format(file_path))
 
 
 def check_pytorch_version():
-- 
GitLab