From 6cd78f4646f7b6876d132e6307dc99e0186103a3 Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Sun, 13 May 2018 12:36:04 +0300 Subject: [PATCH] Change the way a module callback resolves the module name When we are traversing the forward path of a graph, by invoking each module's forward_hook callback, we sometimes want to know the full name of the module. Previously, to infer the module name, we looked up the name of self.weight parameter and used that to get the module name. In PyTorch 0.4 we can directly look up the module name using model_find_module_name. --- distiller/__init__.py | 16 ++++++++++++++++ distiller/model_summaries.py | 6 ++---- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/distiller/__init__.py b/distiller/__init__.py index c63517d..1249b1d 100755 --- a/distiller/__init__.py +++ b/distiller/__init__.py @@ -52,6 +52,22 @@ def model_find_param_name(model, param_to_find): return name return None + +def model_find_module_name(model, module_to_find): + """Look up the name of a module in a model. + + Arguments: + model: the model to search + module_to_find: the module whose name we want to look up + + Returns: + The module name (string) or None, if the module was not found. + """ + for name, m in model.named_modules(): + if m == module_to_find: + return name + return None + def model_find_param(model, param_to_find_name): """Look a model parameter by its name diff --git a/distiller/model_summaries.py b/distiller/model_summaries.py index 376e91e..4897779 100755 --- a/distiller/model_summaries.py +++ b/distiller/model_summaries.py @@ -134,6 +134,7 @@ def conv_visitor(self, input, output, df, model, memo): assert isinstance(self, torch.nn.Conv2d) if self in memo: return + weights_vol = self.out_channels * self.in_channels * self.kernel_size[0] * self.kernel_size[1] # Multiply-accumulate operations: MACs = volume(OFM) * (#IFM * K^2) / #Groups @@ -158,10 +159,7 @@ def module_visitor(self, input, output, df, model, weights_vol, macs, attrs=None in_features_shape = input[0].size() out_features_shape = output.size() - param_name = distiller.model_find_param_name(model, self.weight) - if param_name is None: - return - mod_name = param_name[:param_name.find(".weight")] + mod_name = distiller.model_find_module_name(model, self) df.loc[len(df.index)] = ([mod_name, self.__class__.__name__, attrs if attrs is not None else '', distiller.size_to_str(in_features_shape), distiller.volume(input[0]), -- GitLab