diff --git a/distiller/data_loggers/collector.py b/distiller/data_loggers/collector.py index a18a368efa3061bf1a0c594b1be799d83bc8e62d..c5d7977674048c18aaa072e5e17ac2bc4a6f4122 100755 --- a/distiller/data_loggers/collector.py +++ b/distiller/data_loggers/collector.py @@ -14,6 +14,7 @@ # limitations under the License. # +from functools import partial import xlsxwriter import os from collections import OrderedDict @@ -76,14 +77,9 @@ class ActivationStatsCollector(object): def value(self): """Return a dictionary containing {layer_name: statistic}""" activation_stats = OrderedDict() - self.__value(self.model, activation_stats) + self.model.apply(partial(self._collect_activations_stats, activation_stats=activation_stats)) return activation_stats - def __value(self, module, activation_stats): - for child_module in module._modules.values(): - self.__value(child_module, activation_stats) - self._collect_activations_stats(module, activation_stats) - def start(self): """Start collecting activation stats. @@ -91,7 +87,17 @@ class ActivationStatsCollector(object): will be called from the forward traversal and get exposed to activation data. """ assert len(self.fwd_hook_handles) == 0 - self.__start(self.model) + self.model.apply(self.start_module) + + def start_module(self, module): + """Iteratively register to the forward-pass callback of all eligable modules. + + Eligable modules are currently filtered by their class type. + """ + is_leaf_node = len(list(module.children())) == 0 + if is_leaf_node and type(module) in self.classes: + self.fwd_hook_handles.append(module.register_forward_hook(self._activation_stats_cb)) + self._start_counter(module) def stop(self): """Stop collecting activation stats. @@ -104,14 +110,9 @@ class ActivationStatsCollector(object): def reset(self): """Reset the statsitics counters of this collector.""" - self.__reset(self.model) + self.model.apply(self._reset_counter) return self - def __reset(self, module): - for child_module in module._modules.values(): - self.__reset(child_module) - self._reset_counter(module) - def __activation_stats_cb(self, module, input, output): """Handle new activations ('output' argument). @@ -119,22 +120,7 @@ class ActivationStatsCollector(object): """ raise NotImplementedError - def __start(self, module, name=''): - """Iteratively register to the forward-pass callback of all eligable modules. - - Eligable modules are currently filtered by their class type. - """ - is_leaf_node = True - for name, sub_module in module._modules.items(): - self.__start(sub_module, name) - is_leaf_node = False - - if is_leaf_node: - if type(module) in self.classes: - self.fwd_hook_handles.append(module.register_forward_hook(self._activation_stats_cb)) - self._start_counter(module) - - def _reset_counter(self, mod): + def _reset_counter(self, module): """Reset a specific statistic counter - this is subclass-specific code""" raise NotImplementedError @@ -172,9 +158,9 @@ class SummaryActivationStatsCollector(ActivationStatsCollector): module.__class__.__name__, str(id(module)))) - def _reset_counter(self, mod): - if hasattr(mod, self.stat_name): - getattr(mod, self.stat_name).reset() + def _reset_counter(self, module): + if hasattr(module, self.stat_name): + getattr(module, self.stat_name).reset() def _collect_activations_stats(self, module, activation_stats, name=''): if hasattr(module, self.stat_name): @@ -276,9 +262,9 @@ class RecordsActivationStatsCollector(ActivationStatsCollector): if not hasattr(module, "statsitics_records"): module.statsitics_records = self._create_records_dict() - def _reset_counter(self, mod): - if hasattr(mod, "statsitics_records"): - mod.statsitics_records = self._create_records_dict() + def _reset_counter(self, module): + if hasattr(module, "statsitics_records"): + module.statsitics_records = self._create_records_dict() def _collect_activations_stats(self, module, activation_stats, name=''): if hasattr(module, "statsitics_records"): diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index 380a4c455a3571bd129b80cfd5b8231fff1aedaa..7d70958e2f8953a3789b542fea268be30bdb434f 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -212,13 +212,14 @@ def create_activation_stats_collectors(model, collection_phase): activations_collectors = {"train": missingdict(), "valid": missingdict(), "test": missingdict()} if collection_phase is None: return activations_collectors - collectors = missingdict() - collectors["sparsity"] = SummaryActivationStatsCollector(model, "sparsity", distiller.utils.sparsity) - collectors["l1_channels"] = SummaryActivationStatsCollector(model, "l1_channels", - distiller.utils.activation_channels_l1) - collectors["apoz_channels"] = SummaryActivationStatsCollector(model, "apoz_channels", - distiller.utils.activation_channels_apoz) - collectors["records"] = RecordsActivationStatsCollector(model, classes=[torch.nn.Conv2d]) + collectors = missingdict({ + "sparsity": SummaryActivationStatsCollector(model, "sparsity", + lambda t: 100 * distiller.utils.sparsity(t)), + "l1_channels": SummaryActivationStatsCollector(model, "l1_channels", + distiller.utils.activation_channels_l1), + "apoz_channels": SummaryActivationStatsCollector(model, "apoz_channels", + distiller.utils.activation_channels_apoz), + "records": RecordsActivationStatsCollector(model, classes=[torch.nn.Conv2d])}) activations_collectors[collection_phase] = collectors return activations_collectors @@ -227,7 +228,9 @@ def save_collectors_data(collectors, directory): """Utility function that saves all activation statistics to Excel workbooks """ for name, collector in collectors.items(): - collector.to_xlsx(os.path.join(directory, name)) + workbook = os.path.join(directory, name) + msglogger.info("Generating {}".format(workbook)) + collector.to_xlsx(workbook) def main(): @@ -538,6 +541,7 @@ def test(test_loader, model, criterion, loggers, activations_collectors, args): with collectors_context(activations_collectors["test"]) as collectors: top1, top5, lossses = _validate(test_loader, model, criterion, loggers, args) distiller.log_activation_statsitics(-1, "test", loggers, collector=collectors['sparsity']) + save_collectors_data(collectors, msglogger.logdir) return top1, top5, lossses