Unverified Commit b60a33ef authored by Guy Jacob, committed by GitHub

SummaryGraph: Add adjacency map + numerous changes (#291)

* Adjacency map - map from each op to its predecessor and successor ops (see the usage
  sketch below)
* More robust handling of Gemm nodes' scope names (instead of using
  increment_instance())
* More consistent handling of ops with the same scope name
* Handle pad + avg pool sequences generated by ONNX trace optimization
  (results in one less op in the graph, hence the changes in tests)
* Minor refactoring in predecessors() and successors() functions
parent 4c7d4890
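For context, a minimal usage sketch of the new adjacency-map API. It assumes SummaryGraph is exposed at the distiller package level and uses a torchvision ResNet purely as an example; the printed output in the comment is illustrative, not taken from this commit:

    import torch
    import torchvision
    import distiller

    model = torchvision.models.resnet18()
    dummy_input = torch.randn(1, 3, 224, 224)
    sg = distiller.SummaryGraph(model, dummy_input)

    # Map each op to its immediate neighbors, keeping only ops backed by a dedicated module
    adj_map = sg.adjacency_map(dedicated_modules_only=True)
    for op_name, entry in adj_map.items():
        print(op_name, '->', entry)   # e.g. "conv1 -> Predecessors: [] ; Successors: ['bn1']"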
@@ -21,45 +21,20 @@ import collections
 import torch
 import torch.jit as jit
 import logging
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict

 msglogger = logging.getLogger()


-def onnx_name_2_pytorch_name(name, op_type):
+def onnx_name_2_pytorch_name(name):
     # Convert a layer's name from an ONNX name, to a PyTorch name
     # For example:
-    #   ResNet/Sequential[layer3]/BasicBlock[0]/ReLU[relu].1 ==> layer3.0.relu.1
-    # First see if there's an instance identifier
-    instance = ''
-    if name.find('.') >= 0:
-        instance = name[name.find('.')+1:]
-
-    # Next, split by square brackets
+    #   ResNet/Sequential[layer3]/BasicBlock[0]/ReLU[relu] ==> layer3.0.relu
+    # Split by square brackets
     name_parts = re.findall('\[.*?\]', name)
     name_parts = [part[1:-1] for part in name_parts]
-
-    # If name doesn't have the pattern above, it probably means the op was called via
-    # some functional API and not via a module. Couple of examples:
-    #   x = x.view(...)
-    #   x = F.relu(x)
-    # In this case, to have a meaningful name, we use the op type
-    new_name = ('.'.join(name_parts) if len(name_parts) > 0 else op_type) + instance
-    msglogger.debug("new sgraph node {} {} {}".format(name, op_type, new_name))
-    return new_name
-
-
-def increment_instance(node_name):
-    """Increment the instance number of a given node"""
-    try:
-        # There is an assumption here that the last character in node_name is the node instance (an integer),
-        # and that it is between 0-9 (i.e. a digit)
-        base_name = node_name[:-1]
-        suffix = str(int(node_name[-1]) + 1)
-        return base_name + suffix
-    except ValueError:
-        return node_name + ".0"
+    return '.'.join(name_parts)


 class SummaryGraph(object):
@@ -102,12 +77,20 @@ class SummaryGraph(object):
         dummy_input = distiller.convert_tensors_recursively_to(dummy_input, device=device)
         trace, _ = jit.get_trace_graph(model_clone, dummy_input, _force_outplace=True)

+        # ONNX trace optimization has issues with Gemm ops (aka "Linear" / "addmm" / "FC"), where
+        # Gemm nodes get the scope name of the last non-Gemm node that came before them. This can make
+        # it impossible, in some cases, to derive the connectivity of the model using the original
+        # module names. So we save the scope names for these nodes from the un-optimized trace.
+        aten_addmm_nodes_scope_names = [n.scopeName() for n in trace.graph().nodes() if n.kind() == 'aten::addmm']
+        onnx_gemm_count = 0
+
         # Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes
         # composing a GEMM operation; etc.
         torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)

         graph = trace.graph()

         self.ops = OrderedDict()
+        self.module_ops_map = defaultdict(list)
         self.params = OrderedDict()
         self.edges = []
         self.temp = OrderedDict()
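The workaround above relies on a positional correspondence: the k-th Gemm node in the optimized ONNX graph originates from the k-th aten::addmm node in the un-optimized trace, so the scope names saved beforehand can be re-applied in order. A self-contained sketch of that idea (all names and scopes below are made up):

    # Scope names of aten::addmm nodes, collected from the trace BEFORE ONNX optimization
    addmm_scope_names = ['ResNet/Linear[fc]']

    # (type, scope name) pairs as they might look AFTER optimization - note the Gemm node
    # inherited the scope of the preceding non-Gemm node
    optimized_ops = [
        ('AveragePool', 'ResNet/AvgPool2d[avgpool]'),
        ('Gemm',        'ResNet/AvgPool2d[avgpool]'),   # wrong scope
    ]

    gemm_count = 0
    fixed_ops = []
    for op_type, scope in optimized_ops:
        if op_type == 'Gemm':
            scope = addmm_scope_names[gemm_count]        # restored: 'ResNet/Linear[fc]'
            gemm_count += 1
        fixed_ops.append((op_type, scope))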
@@ -119,26 +102,48 @@ class SummaryGraph(object):
         for node in graph.nodes():
             new_op = self.__create_op(node)

-            # Operators with the same name create very confusing graphs (Resnet, for example),
-            # so we "unroll" them.
-            # Sometimes operations of different types have the same name, so we differentiate
-            # using both name and type
-            # (this happens, for example, when an operator is called via some functional API and
-            # not via a module)
-            same = [op for op in self.ops.values() if
-                    op['orig-name'] + op['type'] == new_op['orig-name'] + new_op['type']]
-            if len(same) > 0:
-                new_op['name'] += "." + str(len(same))
-
-            new_op['name'] = onnx_name_2_pytorch_name(new_op['name'], new_op['type'])
-            assert len(new_op['name']) > 0
-
-            if new_op['name'] in self.ops:
-                # This is a patch.
-                # ONNX names integrate the node type, while we don't (design bug).
-                # This means that while parsing the ONNX graph we might find two nodes with the "same" name.
-                # This patch increments the instance name, but this may break in the future.
-                new_op['name'] = increment_instance(new_op['name'])
+            # Here we apply the workaround to the Gemm nodes scope name issue mentioned above
+            if new_op['type'] == 'Gemm':
+                new_op['orig-name'] = aten_addmm_nodes_scope_names[onnx_gemm_count]
+                new_op['name'] = new_op['orig-name']
+                onnx_gemm_count += 1
+
+            # Convert the graph node's scope name to a PyTorch module name
+            module_name = onnx_name_2_pytorch_name(new_op['orig-name'])
+            new_op['module-name'] = module_name
+            if len(module_name) == 0:
+                # Special case where the module name is an empty string - this happens
+                # when the op is called from the "top-level" of the model
+                new_op['name'] = 'top_level_op'
+            else:
+                new_op['name'] = module_name
+
+            # The node's scope name in the graph corresponds to the module from which the op was called.
+            # This means that when ops are invoked from the same module via functional calls or direct
+            # operations on tensors, these ops will have the SAME MODULE NAME associated with them.
+            # For example:
+            #   t = t1 + t2
+            #   t = F.relu(t)
+            # In this case the add operation and the ReLU operation will have the same name, which is
+            # derived from the module they're contained in.
+            #
+            # Another case where different ops will have the same module name is when a module is reused:
+            #   out = self.conv1(x)
+            #   out = self.relu(out)    <=== First use of self.relu
+            #   out = self.conv2(out)
+            #   out = self.relu(out)    <=== Second use of self.relu
+            # In this case the graph will have 2 distinct ReLU nodes, with the same scope name.
+            #
+            # Operators with the same name create very confusing graphs (in ResNet, for example),
+            # so we "unroll" them.
+            same_module_cnt = len(self.module_ops_map[module_name])
+            if same_module_cnt:
+                new_op['name'] += "__" + str(same_module_cnt)
+            self.module_ops_map[module_name].append(new_op['name'])
+
+            # Finally we register the new op in the ops collection
+            msglogger.debug("new sgraph node - Scope name: {} ; Type: {} ; Display name {}".format(
+                new_op['orig-name'], new_op['type'], new_op['name']))
             self.ops[new_op['name']] = new_op

             for input_ in node.inputs():
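A self-contained sketch of how the per-module counter "unrolls" repeated display names (the module names below are made up, following a ResNet-style model):

    from collections import defaultdict

    module_ops_map = defaultdict(list)

    def display_name(module_name):
        # Mirrors the naming logic above: empty module name -> 'top_level_op',
        # repeated module name -> '__<count>' suffix
        name = module_name if module_name else 'top_level_op'
        same_module_cnt = len(module_ops_map[module_name])
        if same_module_cnt:
            name += '__' + str(same_module_cnt)
        module_ops_map[module_name].append(name)
        return name

    print(display_name('layer1.0.relu'))   # layer1.0.relu      (first use of self.relu)
    print(display_name('layer1.0.relu'))   # layer1.0.relu__1   (second use of self.relu)
    print(display_name(''))                # top_level_op       (e.g. t = t1 + t2 in forward())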
@@ -151,11 +156,44 @@ class SummaryGraph(object):
                 new_op['attrs'] = OrderedDict([(attr_name, node[attr_name]) for attr_name in node.attributeNames()])

+        self.__merge_pad_avgpool()
         self.add_macs_attr()
         self.add_footprint_attr()
         self.add_arithmetic_intensity_attr()
         del model_clone

+    def __merge_pad_avgpool(self):
+        """ The ONNX trace optimization converts average pool ops to a sequence of 2 operations: pad + pool.
+        This "quirk" makes it unnecessarily difficult to detect the connectivity between an average pool
+        op and its predecessor, and it doesn't serve any purpose in the context of SummaryGraph usages.
+        So we get rid of the pad op here.
+        """
+        pad_op_name = None
+        for curr_op_name, curr_op in list(self.ops.items()):
+            curr_op_type = curr_op['type']
+            if curr_op_type == 'Pad':
+                pad_op_name = curr_op_name
+            else:
+                if pad_op_name and curr_op_type == 'AveragePool':
+                    pad_op = self.ops[pad_op_name]
+                    if pad_op['module-name'] != curr_op['module-name']:
+                        continue
+                    merged_op = OrderedDict(curr_op)
+                    merged_op['name'] = pad_op_name
+                    merged_op['inputs'] = pad_op['inputs']
+                    self.ops[pad_op_name] = merged_op
+                    self.ops.pop(curr_op_name)
+                    self.module_ops_map[merged_op['module-name']].remove(curr_op_name)
+                    sequence_input_idx = pad_op['inputs'][0]
+                    first_edge = SummaryGraph.Edge(sequence_input_idx, pad_op_name)
+                    idx = self.edges.index(first_edge)
+                    del self.edges[idx:idx + 4]
+                    self.edges.insert(idx, SummaryGraph.Edge(sequence_input_idx, pad_op_name))
+                    self.edges.insert(idx + 1, SummaryGraph.Edge(pad_op_name, merged_op['outputs'][0]))
+                pad_op_name = None
+
     def __create_op(self, onnx_node):
         op = OrderedDict()
         op['name'] = onnx_node.scopeName()
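A self-contained sketch of the edge splice that __merge_pad_avgpool() performs, assuming the Pad and AveragePool ops came from the same AvgPool2d module and were therefore displayed as 'avgpool' and 'avgpool__1' (tensor ids and names are made up):

    import collections

    Edge = collections.namedtuple('Edge', ['src', 'dst'])

    edges = [
        Edge('t0', 'avgpool'),        # sequence input tensor -> Pad op
        Edge('avgpool', 't1'),        # Pad op -> intermediate tensor
        Edge('t1', 'avgpool__1'),     # intermediate tensor -> AveragePool op
        Edge('avgpool__1', 't2'),     # AveragePool op -> sequence output tensor
    ]

    # The four edges of the pad + pool sequence collapse into two edges for the merged op,
    # which keeps the Pad op's name and inputs and the AveragePool op's outputs:
    idx = edges.index(Edge('t0', 'avgpool'))
    del edges[idx:idx + 4]
    edges.insert(idx, Edge('t0', 'avgpool'))
    edges.insert(idx + 1, Edge('avgpool', 't2'))
    print(edges)  # [Edge(src='t0', dst='avgpool'), Edge(src='avgpool', dst='t2')]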
@@ -280,19 +318,15 @@ class SummaryGraph(object):
     def find_param(self, data_name):
         return self.params.get(data_name, None)

-    def predecessors(self, op, depth, done_list=None):
+    def predecessors(self, node, depth, done_list=None):
         """Returns a list of <op>'s predecessors"""
         if done_list is None:
             done_list = []

-        if isinstance(op, dict):
-            preds = [edge.src for edge in self.edges if (edge.dst == op['name'] and
-                                                         edge.src not in done_list)]
-            done_list += preds
-        else:
-            preds = [edge.src for edge in self.edges if (edge.dst == op and
-                                                         edge.src not in done_list)]
-            done_list += preds
+        node_name = node['name'] if isinstance(node, dict) else node
+        preds = [edge.src for edge in self.edges if (edge.dst == node_name and
+                                                     edge.src not in done_list)]
+        done_list += preds

         if depth == 1:
             ret = preds
@@ -348,16 +382,10 @@ class SummaryGraph(object):
         if done_list is None:
             done_list = []

-        if isinstance(node, dict):
-            # This is an operation node
-            succs = [edge.dst for edge in self.edges if (edge.src == node['name'] and
-                                                         edge.dst not in done_list)]
-            done_list += succs
-        else:
-            # This is a data node
-            succs = [edge.dst for edge in self.edges if (edge.src == node and
-                                                         edge.dst not in done_list)]
-            done_list += succs
+        node_name = node['name'] if isinstance(node, dict) else node
+        succs = [edge.dst for edge in self.edges if (edge.src == node_name and
+                                                     edge.dst not in done_list)]
+        done_list += succs

         if depth == 1:
             ret = succs
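Since edges alternate between op nodes and data (tensor) nodes, a depth of 1 from an op returns tensor names while a depth of 2 returns the neighboring ops. Continuing the usage sketch above, with op names that are illustrative for a ResNet-style model:

    input_tensors = sg.predecessors('layer1.0.conv1', 1)    # depth 1: the op's input tensors
    producer_ops  = sg.predecessors('layer1.0.conv1', 2)    # depth 2: the ops producing those tensors
    consumer_ops  = sg.successors('layer1.0.conv1', 2)      # e.g. ['layer1.0.bn1']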
@@ -423,3 +451,45 @@ class SummaryGraph(object):
                 sgraph_layer_name = distiller.denormalize_module_name(
                     self._src_model, normalized_layer_name)
                 yield sgraph_layer_name, param_name, param
+
+    def adjacency_map(self, dedicated_modules_only=False):
+        """Returns a mapping from each op in the graph to its immediate predecessors and successors.
+
+        The keys in the generated mapping are op names, and the values are instances of AdjacentsEntry.
+
+        The op names are "de-normalized", meaning they can be used directly with the underlying model's
+        named_modules(), for example.
+
+        Args:
+            dedicated_modules_only (bool): If set, the generated mapping will not include any ops that can't be
+              associated with a dedicated module within the underlying model. Examples of this will be
+              functional calls, such as "F.relu()", and tensor operations, such as "t3 = t1 + t2".
+        """
+        adj_map = OrderedDict()
+
+        for op_name, op in self.ops.items():
+            def dedicated_module_check(n):
+                module_name = self.ops[distiller.normalize_module_name(n)]['module-name']
+                return len(self.module_ops_map[module_name]) == 1 or not dedicated_modules_only
+
+            if not dedicated_module_check(op_name):
+                continue
+
+            entry = AdjacentsEntry()
+            # Find the immediate preceding and succeeding modules. Depth of 1 gets us the
+            # input and output tensors, depth of 2 gets the actual modules
+            entry.predecessors = [n for n in self.predecessors(op, 2) if dedicated_module_check(n)]
+            entry.successors = [n for n in self.successors(op, 2) if dedicated_module_check(n)]
+
+            adj_map[distiller.denormalize_module_name(self._src_model, op_name)] = entry
+
+        return adj_map
+
+
+class AdjacentsEntry(object):
+    def __init__(self):
+        self.predecessors = []
+        self.successors = []
+
+    def __repr__(self):
+        return 'Predecessors: {0} ; Successors: {1}'.format(self.predecessors, self.successors)
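Note that with dedicated_modules_only=True, ops whose module hosts more than one graph op (a reused nn.ReLU, a functional F.relu(), a tensor add, and so on) are dropped both as keys and from the neighbor lists; they are simply omitted, not bridged over. An illustrative continuation of the usage sketch above (the exact names assume a torchvision ResNet):

    adj_map = sg.adjacency_map(dedicated_modules_only=True)
    # torchvision's BasicBlock reuses self.relu, so 'layer1.0.relu' and 'layer1.0.relu__1'
    # share one module and are filtered out of the mapping:
    print('layer1.0.relu' in adj_map)   # False
    print(adj_map['layer1.0.conv1'])    # Predecessors: [...] ; Successors: ['layer1.0.bn1']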
@@ -49,7 +49,7 @@ def test_connectivity():
     assert g is not None

     op_names = [op['name'] for op in g.ops.values()]
-    assert 81 == len(op_names)
+    assert len(op_names) == 80

     edges = g.edges
     assert edges[0].src == '0' and edges[0].dst == 'conv1'
@@ -168,10 +168,9 @@ def test_named_params_layers():

 def test_onnx_name_2_pytorch_name():
-    assert "layer3.0.relu1" == onnx_name_2_pytorch_name("ResNet/Sequential[layer3]/BasicBlock[0]/ReLU[relu].1", 'Relu')
-    assert "features.34" == onnx_name_2_pytorch_name('VGG/[features]/Sequential/Conv2d[34]', 'Conv')
-    assert "Relu3" == onnx_name_2_pytorch_name('NameWithNoModule.3', 'Relu')
-    #assert "features.module.34" == onnx_name_2_pytorch_name('VGG/DataParallel[features]/Sequential/Conv2d[34]', 'Conv')
+    assert onnx_name_2_pytorch_name("ResNet/Sequential[layer3]/BasicBlock[0]/ReLU[relu]") == "layer3.0.relu"
+    assert onnx_name_2_pytorch_name('VGG/[features]/Sequential/Conv2d[34]') == "features.34"
+    assert onnx_name_2_pytorch_name('NameWithNoModule') == ''


 def test_connectivity_summary():
@@ -179,10 +178,10 @@ def test_connectivity_summary():
     assert g is not None

     summary = connectivity_summary(g)
-    assert len(summary) == 81
+    assert len(summary) == 80

     verbose_summary = connectivity_summary_verbose(g)
-    assert len(verbose_summary) == 81
+    assert len(verbose_summary) == 80


 def test_sg_macs():