diff --git a/distiller/data_loggers/collector.py b/distiller/data_loggers/collector.py
index 11702b30cf95ca1b26388ddfbb1c6aa2beeeb5a8..68fb60787e732db183947a2d7ede993bebef0d2f 100755
--- a/distiller/data_loggers/collector.py
+++ b/distiller/data_loggers/collector.py
@@ -146,7 +146,9 @@ class SummaryActivationStatsCollector(ActivationStatsCollector):
     light-weight and quicker than collecting a record per activation.
     The statistic function is configured in the constructor.
     """
-    def __init__(self, model, stat_name, summary_fn, classes=[torch.nn.ReLU]):
+    def __init__(self, model, stat_name, summary_fn, classes=[torch.nn.ReLU,
+                                                              torch.nn.ReLU6,
+                                                              torch.nn.LeakyReLU]):
         super(SummaryActivationStatsCollector, self).__init__(model, stat_name, classes)
         self.summary_fn = summary_fn
 
@@ -223,7 +225,9 @@ class RecordsActivationStatsCollector(ActivationStatsCollector):
 
     For obvious reasons, this is slower than SummaryActivationStatsCollector.
     """
-    def __init__(self, model, classes=[torch.nn.ReLU]):
+    def __init__(self, model, classes=[torch.nn.ReLU,
+                                       torch.nn.ReLU6,
+                                       torch.nn.LeakyReLU]):
         super(RecordsActivationStatsCollector, self).__init__(model, "statistics_records", classes)
 
     def _activation_stats_cb(self, module, input, output):
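
A minimal usage sketch of the widened defaults, not part of the patch itself: it assumes distiller is importable, that the constructors match the signatures shown above, and that summary_fn is applied to an activation tensor. The toy model, the "mean_abs" stat name, and the lambda are placeholder assumptions for illustration only.

import torch.nn as nn

from distiller.data_loggers.collector import (SummaryActivationStatsCollector,
                                              RecordsActivationStatsCollector)

# Toy model mixing the activation types now matched by the default `classes`.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU6(),                        # previously required an explicit `classes` override
    nn.Conv2d(8, 8, kernel_size=3, padding=1),
    nn.LeakyReLU(negative_slope=0.1),  # previously required an explicit `classes` override
)

# With the new defaults, no explicit `classes` argument is needed for models
# that use ReLU6 or LeakyReLU instead of plain ReLU.
summary_collector = SummaryActivationStatsCollector(
    model, "mean_abs",                          # placeholder stat name
    summary_fn=lambda t: t.abs().mean())        # assumes summary_fn receives a tensor
records_collector = RecordsActivationStatsCollector(model)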