From 3b4633b70733875b2cc3be53e0177bccc43822c7 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Mon, 24 Jun 2019 16:55:18 +0300
Subject: [PATCH] activation stats collection: add more layer types to default
 module list

Added ReLU6 and LeakyReLU to the list of default module types for which
we collect activation stats.
---
 distiller/data_loggers/collector.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/distiller/data_loggers/collector.py b/distiller/data_loggers/collector.py
index 11702b3..68fb607 100755
--- a/distiller/data_loggers/collector.py
+++ b/distiller/data_loggers/collector.py
@@ -146,7 +146,9 @@ class SummaryActivationStatsCollector(ActivationStatsCollector):
     light-weight and quicker than collecting a record per activation.
     The statistic function is configured in the constructor.
     """
-    def __init__(self, model, stat_name, summary_fn, classes=[torch.nn.ReLU]):
+    def __init__(self, model, stat_name, summary_fn, classes=[torch.nn.ReLU,
+                                                              torch.nn.ReLU6,
+                                                              torch.nn.LeakyReLU]):
         super(SummaryActivationStatsCollector, self).__init__(model, stat_name, classes)
         self.summary_fn = summary_fn
 
@@ -223,7 +225,9 @@ class RecordsActivationStatsCollector(ActivationStatsCollector):
 
     For obvious reasons, this is slower than SummaryActivationStatsCollector.
     """
-    def __init__(self, model, classes=[torch.nn.ReLU]):
+    def __init__(self, model, classes=[torch.nn.ReLU,
+                                       torch.nn.ReLU6,
+                                       torch.nn.LeakyReLU]):
         super(RecordsActivationStatsCollector, self).__init__(model, "statistics_records", classes)
 
     def _activation_stats_cb(self, module, input, output):
-- 
GitLab