diff --git a/distiller/data_loggers/collector.py b/distiller/data_loggers/collector.py index 9ceed2357737f0324237f18d2694851573710fca..9b0350b3e01613a84ae515c32dd9aa0511c3dea2 100755 --- a/distiller/data_loggers/collector.py +++ b/distiller/data_loggers/collector.py @@ -113,7 +113,7 @@ class ActivationStatsCollector(object): return self def save(self, fname): - pass + raise NotImplementedError def _activation_stats_cb(self, module, input, output): """Handle new activations ('output' argument). @@ -207,6 +207,7 @@ class SummaryActivationStatsCollector(ActivationStatsCollector): worksheet.write_column(1, col, module_summary_data) col_names.append(module_name) worksheet.write_row(0, 0, col_names) + return fname class RecordsActivationStatsCollector(ActivationStatsCollector): @@ -283,6 +284,7 @@ class RecordsActivationStatsCollector(ActivationStatsCollector): col_names.append(col_name) worksheet.write_row(0, 0, col_names) worksheet.write(0, len(col_names)+2, module_act_records['shape']) + return fname def _start_counter(self, module): if not hasattr(module, "statistics_records"): @@ -451,6 +453,8 @@ class QuantCalibrationStatsCollector(ActivationStatsCollector): with open(fname, 'w') as f: yaml.dump(records_dict, f, default_flow_style=False) + return fname + @contextmanager def collector_context(collector): diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index ad499a75451551d58f94f730e31772b1c38b0712..1582f8726352bd3cb245ed5ba36bfc342c3c5e56 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -755,12 +755,14 @@ def create_quantization_stats_collector(model): def save_collectors_data(collectors, directory): - """Utility function that saves all activation statistics to Excel workbooks + """Utility function that saves all activation statistics to disk. + + File type and format of contents are collector-specific. """ for name, collector in collectors.items(): - workbook = os.path.join(directory, name) - msglogger.info("Generating {}".format(workbook)) - collector.save(workbook) + msglogger.info('Saving data for collector {}...'.format(name)) + file_path = collector.save(os.path.join(directory, name)) + msglogger.info("Saved to {}".format(file_path)) def check_pytorch_version():