Skip to content
Snippets Groups Projects
Commit 6cd78f46 authored by Neta Zmora's avatar Neta Zmora
Browse files

Change the way a module callback resolves the module name

When we are traversing the forward path of a graph, by invoking each
module's forward_hook callback, we sometimes want to know the full
name of the module.
Previously, to infer the module name, we looked up the name of self.weight
parameter and used that to get the module name.
In PyTorch 0.4 we can directly look up the module name using
model_find_module_name.
parent 957e6777
No related branches found
No related tags found
No related merge requests found
...@@ -52,6 +52,22 @@ def model_find_param_name(model, param_to_find): ...@@ -52,6 +52,22 @@ def model_find_param_name(model, param_to_find):
return name return name
return None return None
def model_find_module_name(model, module_to_find):
"""Look up the name of a module in a model.
Arguments:
model: the model to search
module_to_find: the module whose name we want to look up
Returns:
The module name (string) or None, if the module was not found.
"""
for name, m in model.named_modules():
if m == module_to_find:
return name
return None
def model_find_param(model, param_to_find_name): def model_find_param(model, param_to_find_name):
"""Look a model parameter by its name """Look a model parameter by its name
......
...@@ -134,6 +134,7 @@ def conv_visitor(self, input, output, df, model, memo): ...@@ -134,6 +134,7 @@ def conv_visitor(self, input, output, df, model, memo):
assert isinstance(self, torch.nn.Conv2d) assert isinstance(self, torch.nn.Conv2d)
if self in memo: if self in memo:
return return
weights_vol = self.out_channels * self.in_channels * self.kernel_size[0] * self.kernel_size[1] weights_vol = self.out_channels * self.in_channels * self.kernel_size[0] * self.kernel_size[1]
# Multiply-accumulate operations: MACs = volume(OFM) * (#IFM * K^2) / #Groups # Multiply-accumulate operations: MACs = volume(OFM) * (#IFM * K^2) / #Groups
...@@ -158,10 +159,7 @@ def module_visitor(self, input, output, df, model, weights_vol, macs, attrs=None ...@@ -158,10 +159,7 @@ def module_visitor(self, input, output, df, model, weights_vol, macs, attrs=None
in_features_shape = input[0].size() in_features_shape = input[0].size()
out_features_shape = output.size() out_features_shape = output.size()
param_name = distiller.model_find_param_name(model, self.weight) mod_name = distiller.model_find_module_name(model, self)
if param_name is None:
return
mod_name = param_name[:param_name.find(".weight")]
df.loc[len(df.index)] = ([mod_name, self.__class__.__name__, df.loc[len(df.index)] = ([mod_name, self.__class__.__name__,
attrs if attrs is not None else '', attrs if attrs is not None else '',
distiller.size_to_str(in_features_shape), distiller.volume(input[0]), distiller.size_to_str(in_features_shape), distiller.volume(input[0]),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment