From 3b4633b70733875b2cc3be53e0177bccc43822c7 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Mon, 24 Jun 2019 16:55:18 +0300
Subject: [PATCH] activation stats collection: add more layer types to default
 module list

Added ReLU6 and LeakyReLU to the list of default module types
for which we collect activation stats.
---
 distiller/data_loggers/collector.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/distiller/data_loggers/collector.py b/distiller/data_loggers/collector.py
index 11702b3..68fb607 100755
--- a/distiller/data_loggers/collector.py
+++ b/distiller/data_loggers/collector.py
@@ -146,7 +146,9 @@ class SummaryActivationStatsCollector(ActivationStatsCollector):
     light-weight and quicker than collecting a record per activation.
 
     The statistic function is configured in the constructor.
     """
-    def __init__(self, model, stat_name, summary_fn, classes=[torch.nn.ReLU]):
+    def __init__(self, model, stat_name, summary_fn, classes=[torch.nn.ReLU,
+                                                              torch.nn.ReLU6,
+                                                              torch.nn.LeakyReLU]):
         super(SummaryActivationStatsCollector, self).__init__(model, stat_name, classes)
         self.summary_fn = summary_fn
 
@@ -223,7 +225,9 @@ class RecordsActivationStatsCollector(ActivationStatsCollector):
 
     For obvious reasons, this is slower than SummaryActivationStatsCollector.
     """
-    def __init__(self, model, classes=[torch.nn.ReLU]):
+    def __init__(self, model, classes=[torch.nn.ReLU,
+                                       torch.nn.ReLU6,
+                                       torch.nn.LeakyReLU]):
         super(RecordsActivationStatsCollector, self).__init__(model, "statistics_records", classes)
 
     def _activation_stats_cb(self, module, input, output):
-- 
GitLab
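
A minimal usage sketch (not part of the patch) showing what the new defaults buy: constructing a collector without an explicit `classes` argument now hooks ReLU6 and LeakyReLU modules as well as ReLU. It assumes the `start()`/`stop()` hook management that `ActivationStatsCollector` exposes in the version this patch targets; the MobileNetV2 model (which uses ReLU6), the stat name `"mean_activation"`, and the `torch.mean`-based summary function are illustrative choices, not anything mandated by the patch.

```python
import torch
import torchvision

from distiller.data_loggers.collector import SummaryActivationStatsCollector

# MobileNetV2 uses ReLU6 activations; before this patch its activations
# would have been skipped by the default `classes` list.
model = torchvision.models.mobilenet_v2()
model.eval()

# No `classes` argument: the new default list covers ReLU, ReLU6, LeakyReLU.
collector = SummaryActivationStatsCollector(
    model, "mean_activation",
    summary_fn=lambda act: act.mean().item())

collector.start()   # register forward hooks on the matching modules
with torch.no_grad():
    model(torch.randn(1, 3, 224, 224))
collector.stop()    # remove the hooks
```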