From ab8d2960fad83b27d617745f7dd3b84dac819977 Mon Sep 17 00:00:00 2001 From: Bar <elhararb@gmail.com> Date: Tue, 14 May 2019 16:38:19 +0300 Subject: [PATCH] Improved logging when saving collectors data (#251) --- distiller/data_loggers/collector.py | 6 +++++- examples/classifier_compression/compress_classifier.py | 10 ++++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/distiller/data_loggers/collector.py b/distiller/data_loggers/collector.py index 9ceed23..9b0350b 100755 --- a/distiller/data_loggers/collector.py +++ b/distiller/data_loggers/collector.py @@ -113,7 +113,7 @@ class ActivationStatsCollector(object): return self def save(self, fname): - pass + raise NotImplementedError def _activation_stats_cb(self, module, input, output): """Handle new activations ('output' argument). @@ -207,6 +207,7 @@ class SummaryActivationStatsCollector(ActivationStatsCollector): worksheet.write_column(1, col, module_summary_data) col_names.append(module_name) worksheet.write_row(0, 0, col_names) + return fname class RecordsActivationStatsCollector(ActivationStatsCollector): @@ -283,6 +284,7 @@ class RecordsActivationStatsCollector(ActivationStatsCollector): col_names.append(col_name) worksheet.write_row(0, 0, col_names) worksheet.write(0, len(col_names)+2, module_act_records['shape']) + return fname def _start_counter(self, module): if not hasattr(module, "statistics_records"): @@ -451,6 +453,8 @@ class QuantCalibrationStatsCollector(ActivationStatsCollector): with open(fname, 'w') as f: yaml.dump(records_dict, f, default_flow_style=False) + return fname + @contextmanager def collector_context(collector): diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index ad499a7..1582f87 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -755,12 +755,14 @@ def create_quantization_stats_collector(model): def save_collectors_data(collectors, directory): - """Utility function that saves all activation statistics to Excel workbooks + """Utility function that saves all activation statistics to disk. + + File type and format of contents are collector-specific. """ for name, collector in collectors.items(): - workbook = os.path.join(directory, name) - msglogger.info("Generating {}".format(workbook)) - collector.save(workbook) + msglogger.info('Saving data for collector {}...'.format(name)) + file_path = collector.save(os.path.join(directory, name)) + msglogger.info("Saved to {}".format(file_path)) def check_pytorch_version(): -- GitLab