From 6cd78f4646f7b6876d132e6307dc99e0186103a3 Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Sun, 13 May 2018 12:36:04 +0300
Subject: [PATCH] Change the way a module callback resolves the module name

When we are traversing the forward path of a graph, by invoking each
module's forward_hook callback, we sometimes want to know the full
name of the module.
Previously, to infer the module name, we looked up the name of self.weight
parameter and used that to get the module name.
In PyTorch 0.4 we can directly look up the module name using
model_find_module_name.
---
 distiller/__init__.py        | 16 ++++++++++++++++
 distiller/model_summaries.py |  6 ++----
 2 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/distiller/__init__.py b/distiller/__init__.py
index c63517d..1249b1d 100755
--- a/distiller/__init__.py
+++ b/distiller/__init__.py
@@ -52,6 +52,22 @@ def model_find_param_name(model, param_to_find):
             return name
     return None
 
+
+def model_find_module_name(model, module_to_find):
+    """Look up the name of a module in a model.
+
+    Arguments:
+        model: the model to search
+        module_to_find: the module whose name we want to look up
+
+    Returns:
+        The module name (string) or None, if the module was not found.
+    """
+    for name, m in model.named_modules():
+        if m == module_to_find:
+            return name
+    return None
+
 def model_find_param(model, param_to_find_name):
     """Look a model parameter by its name
 
diff --git a/distiller/model_summaries.py b/distiller/model_summaries.py
index 376e91e..4897779 100755
--- a/distiller/model_summaries.py
+++ b/distiller/model_summaries.py
@@ -134,6 +134,7 @@ def conv_visitor(self, input, output, df, model, memo):
     assert isinstance(self, torch.nn.Conv2d)
     if self in memo:
         return
+
     weights_vol = self.out_channels * self.in_channels * self.kernel_size[0] * self.kernel_size[1]
 
     # Multiply-accumulate operations: MACs = volume(OFM) * (#IFM * K^2) / #Groups
@@ -158,10 +159,7 @@ def module_visitor(self, input, output, df, model, weights_vol, macs, attrs=None
     in_features_shape = input[0].size()
     out_features_shape = output.size()
 
-    param_name = distiller.model_find_param_name(model, self.weight)
-    if param_name is None:
-        return
-    mod_name = param_name[:param_name.find(".weight")]
+    mod_name = distiller.model_find_module_name(model, self)
     df.loc[len(df.index)] = ([mod_name, self.__class__.__name__,
                               attrs if attrs is not None else '',
                               distiller.size_to_str(in_features_shape), distiller.volume(input[0]),
-- 
GitLab