From b60a33efdaab3057788ecdb42ef144e27bb0fb74 Mon Sep 17 00:00:00 2001
From: Guy Jacob <guy.jacob@intel.com>
Date: Sun, 23 Jun 2019 13:36:31 +0300
Subject: [PATCH] SummaryGraph: Add adjacency map + numerous changes (#291)

* Adjacency map - map from each op to its predecessor and successor ops
  (see the usage sketch at the end of this list)
* More robust handling of Gemm nodes' scope names (instead of
  increment_instance())
* More consistent handling of ops with the same scope name
* Handle pad + avg pool sequences generated by ONNX trace optimization
  (results in one less op in the graph, hence the changes in tests)
* Minor refactoring in predecessors() and successors() functions
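
The new adjacency_map() API can be exercised roughly as in the sketch below.
This is a minimal, illustrative sketch and not part of the patch: the resnet18
model, the input shape and the variable names (sg, adj_map) are arbitrary
choices, and a PyTorch/torchvision combination supported by this Distiller
revision is assumed.

    # Illustrative sketch (not part of the patch). Assumes this Distiller
    # revision is installed alongside a PyTorch/torchvision version it supports.
    import torch
    import torchvision.models as models
    from distiller.summary_graph import SummaryGraph

    model = models.resnet18()
    dummy_input = torch.randn(1, 3, 224, 224)

    # Trace the model and build the graph summary
    sg = SummaryGraph(model, dummy_input)

    # Map each op to its immediate predecessor and successor ops
    adj_map = sg.adjacency_map()
    for op_name, entry in adj_map.items():
        # Each entry is an AdjacentsEntry listing the op's immediate
        # predecessor and successor ops
        print(op_name, '->', entry)

Per the new docstring, passing dedicated_modules_only=True restricts the
mapping to ops that can be associated with a dedicated module in the model,
dropping functional calls and direct tensor operations.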
---
 distiller/summary_graph.py | 206 +++++++++++++++++++++++++------------
 tests/test_summarygraph.py |  13 ++-
 2 files changed, 144 insertions(+), 75 deletions(-)

diff --git a/distiller/summary_graph.py b/distiller/summary_graph.py
index a5f8ce4..a88ff2c 100755
--- a/distiller/summary_graph.py
+++ b/distiller/summary_graph.py
@@ -21,45 +21,20 @@ import collections
 import torch
 import torch.jit as jit
 import logging
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict
 msglogger = logging.getLogger()
 
 
-def onnx_name_2_pytorch_name(name, op_type):
+def onnx_name_2_pytorch_name(name):
     # Convert a layer's name from an ONNX name, to a PyTorch name
     # For example:
-    #   ResNet/Sequential[layer3]/BasicBlock[0]/ReLU[relu].1 ==> layer3.0.relu.1
+    #   ResNet/Sequential[layer3]/BasicBlock[0]/ReLU[relu] ==> layer3.0.relu
 
-    # First see if there's an instance identifier
-    instance = ''
-    if name.find('.') >= 0:
-        instance = name[name.find('.')+1:]
-
-    # Next, split by square brackets
+    # Split by square brackets
     name_parts = re.findall('\[.*?\]', name)
     name_parts = [part[1:-1] for part in name_parts]
 
-    # If name doesn't have the pattern above, it probably means the op was called via
-    # some functional API and not via a module. Couple of examples:
-    #   x = x.view(...)
-    #   x = F.relu(x)
-    # In this case, to have a meaningful name, we use the op type
-    new_name = ('.'.join(name_parts) if len(name_parts) > 0 else op_type) + instance
-
-    msglogger.debug("new sgraph node {} {} {}".format(name, op_type, new_name))
-    return new_name
-
-
-def increment_instance(node_name):
-    """Increment the instance number of a given node"""
-    try:
-        # There is an assumption here that the last character in node_name is the node instance (an integer),
-        # and that it is between 0-9 (i.e. a digit)
-        base_name = node_name[:-1]
-        suffix = str(int(node_name[-1]) + 1)
-        return base_name + suffix
-    except ValueError:
-        return node_name + ".0"
+    return '.'.join(name_parts)
 
 
 class SummaryGraph(object):
@@ -102,12 +77,20 @@ class SummaryGraph(object):
             dummy_input = distiller.convert_tensors_recursively_to(dummy_input, device=device)
             trace, _ = jit.get_trace_graph(model_clone, dummy_input, _force_outplace=True)
 
+            # ONNX trace optimization has issues with Gemm ops (aka "Linear" / "addmm" / "FC"), where
+            # Gemm nodes get the scope name of the last non-Gemm node that came before them. This can make
+            # it impossible, in some cases, to derive the connectivity of the model using the original
+            # module names. So we save the scope names for these nodes from the un-optimized trace.
+            aten_addmm_nodes_scope_names = [n.scopeName() for n in trace.graph().nodes() if n.kind() == 'aten::addmm']
+            onnx_gemm_count = 0
+
             # Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes
             # composing a GEMM operation; etc.
             torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
 
             graph = trace.graph()
             self.ops = OrderedDict()
+            self.module_ops_map = defaultdict(list)
             self.params = OrderedDict()
             self.edges = []
             self.temp = OrderedDict()
@@ -119,26 +102,48 @@ class SummaryGraph(object):
             for node in graph.nodes():
                 new_op = self.__create_op(node)
 
-                # Operators with the same name create very confusing graphs (Resnet, for example),
+                # Here we apply the workaround to the Gemm nodes' scope name issue mentioned above
+                if new_op['type'] == 'Gemm':
+                    new_op['orig-name'] = aten_addmm_nodes_scope_names[onnx_gemm_count]
+                    new_op['name'] = new_op['orig-name']
+                    onnx_gemm_count += 1
+
+                # Convert the graph node's scope name to a PyTorch module name
+                module_name = onnx_name_2_pytorch_name(new_op['orig-name'])
+                new_op['module-name'] = module_name
+                if len(module_name) == 0:
+                    # Special case where the module name is an empty string - this happens
+                    # when the op is called from the "top-level" of the model
+                    new_op['name'] = 'top_level_op'
+                else:
+                    new_op['name'] = module_name
+
+                # The node's scope name in the graph corresponds to the module from which the op was called.
+                # This means that when ops are invoked from the same module via functional calls or direct
+                # operations on tensors, these ops will have the SAME MODULE NAME associated with them.
+                # For example:
+                #   t = t1 + t2
+                #   t = F.relu(t)
+                # In this case the add operation and the ReLU operation will have the same name, which is
+                # derived from the module they're contained in.
+                #
+                # Another case where different ops will have the same module name is when a module is reused:
+                #   out = self.conv1(x)
+                #   out = self.relu(out)    <=== First use of self.relu
+                #   out = self.conv2(out)
+                #   out = self.relu(out)    <=== Second use of self.relu
+                # In this case the graph will have 2 distinct ReLU nodes, with the same scope name.
+                #
+                # Operators with the same name create very confusing graphs (in ResNet, for example),
                 # so we "unroll" them.
-                # Sometimes operations of different types have the same name, so we differentiate
-                # using both name and type
-                # (this happens, for example, when an operator is called via some functional API and
-                # not via a module)
-                same = [op for op in self.ops.values() if
-                        op['orig-name'] + op['type'] == new_op['orig-name'] + new_op['type']]
-                if len(same) > 0:
-                    new_op['name'] += "." + str(len(same))
-
-                new_op['name'] = onnx_name_2_pytorch_name(new_op['name'], new_op['type'])
-                assert len(new_op['name']) > 0
-
-                if new_op['name'] in self.ops:
-                    # This is a patch.
-                    # ONNX names integrate the node type, while we don't (design bug).
-                    # This means that while parsing the ONNX graph we might find two nodes with the "same" name.
-                    # This patch increments the instance name, but this may break in the future.
-                    new_op['name'] = increment_instance(new_op['name'])
+                same_module_cnt = len(self.module_ops_map[module_name])
+                if same_module_cnt:
+                    new_op['name'] += "__" + str(same_module_cnt)
+                self.module_ops_map[module_name].append(new_op['name'])
+
+                # Finally we register the new op in the ops collection
+                msglogger.debug("new sgraph node - Scope name: {} ; Type: {} ; Display name {}".format(
+                    new_op['orig-name'], new_op['type'], new_op['name']))
                 self.ops[new_op['name']] = new_op
 
                 for input_ in node.inputs():
@@ -151,11 +156,44 @@ class SummaryGraph(object):
 
                 new_op['attrs'] = OrderedDict([(attr_name, node[attr_name]) for attr_name in node.attributeNames()])
 
+        self.__merge_pad_avgpool()
         self.add_macs_attr()
         self.add_footprint_attr()
         self.add_arithmetic_intensity_attr()
         del model_clone
 
+    def __merge_pad_avgpool(self):
+        """ The ONNX trace optimization converts average pool ops to a sequence of 2 operations: pad + pool.
+        This "quirk" makes makes it unnecessarily difficult to detect the connectivity between an average pool
+        op and its predecessor, and it doesn't serve any purpose in the context of SummaryGraph usages.
+        So we get rid of the pad op here.
+        """
+        pad_op_name = None
+        for curr_op_name, curr_op in list(self.ops.items()):
+            curr_op_type = curr_op['type']
+            if curr_op_type == 'Pad':
+                pad_op_name = curr_op_name
+            else:
+                if pad_op_name and curr_op_type == 'AveragePool':
+                    pad_op = self.ops[pad_op_name]
+                    if pad_op['module-name'] != curr_op['module-name']:
+                        continue
+                    merged_op = OrderedDict(curr_op)
+                    merged_op['name'] = pad_op_name
+                    merged_op['inputs'] = pad_op['inputs']
+                    self.ops[pad_op_name] = merged_op
+                    self.ops.pop(curr_op_name)
+                    self.module_ops_map[merged_op['module-name']].remove(curr_op_name)
+
+                    sequence_input_idx = pad_op['inputs'][0]
+                    first_edge = SummaryGraph.Edge(sequence_input_idx, pad_op_name)
+                    idx = self.edges.index(first_edge)
+                    del self.edges[idx:idx + 4]
+                    self.edges.insert(idx, SummaryGraph.Edge(sequence_input_idx, pad_op_name))
+                    self.edges.insert(idx + 1, SummaryGraph.Edge(pad_op_name, merged_op['outputs'][0]))
+
+                pad_op_name = None
+
     def __create_op(self, onnx_node):
         op = OrderedDict()
         op['name'] = onnx_node.scopeName()
@@ -280,19 +318,15 @@ class SummaryGraph(object):
     def find_param(self, data_name):
         return self.params.get(data_name, None)
 
-    def predecessors(self, op, depth, done_list=None):
+    def predecessors(self, node, depth, done_list=None):
         """Returns a list of <op>'s predecessors"""
         if done_list is None:
             done_list = []
 
-        if isinstance(op, dict):
-            preds = [edge.src for edge in self.edges if (edge.dst == op['name'] and
-                                                         edge.src not in done_list)]
-            done_list += preds
-        else:
-            preds = [edge.src for edge in self.edges if (edge.dst == op and
-                                                         edge.src not in done_list)]
-            done_list += preds
+        node_name = node['name'] if isinstance(node, dict) else node
+        preds = [edge.src for edge in self.edges if (edge.dst == node_name and
+                                                     edge.src not in done_list)]
+        done_list += preds
 
         if depth == 1:
             ret = preds
@@ -348,16 +382,10 @@ class SummaryGraph(object):
         if done_list is None:
             done_list = []
 
-        if isinstance(node, dict):
-            # This is an operation node
-            succs = [edge.dst for edge in self.edges if (edge.src == node['name'] and
-                                                         edge.dst not in done_list)]
-            done_list += succs
-        else:
-            # This is a data node
-            succs = [edge.dst for edge in self.edges if (edge.src == node and
-                                                         edge.dst not in done_list)]
-            done_list += succs
+        node_name = node['name'] if isinstance(node, dict) else node
+        succs = [edge.dst for edge in self.edges if (edge.src == node_name and
+                                                     edge.dst not in done_list)]
+        done_list += succs
 
         if depth == 1:
             ret = succs
@@ -423,3 +451,45 @@ class SummaryGraph(object):
             sgraph_layer_name = distiller.denormalize_module_name(
                 self._src_model, normalized_layer_name)
             yield sgraph_layer_name, param_name, param
+
+    def adjacency_map(self, dedicated_modules_only=False):
+        """Returns a mapping from each op in the graph to its immediate predecessors and successors.
+
+        The keys in the generated mapping are op names, and the values are instances of AdjacentsEntry.
+
+        The op names are "de-normalized", meaning they can be used directly with the underlying model's
+        named_modules(), for example.
+
+        Args:
+            dedicated_modules_only (bool): If set, the generated mapping will not include any ops that can't be
+              associated with a dedicated module within the underlying model. Examples of this will be
+              functional calls, such as "F.relu()", and tensor operations, such as "t3 = t1 + t2".
+        """
+        adj_map = OrderedDict()
+
+        for op_name, op in self.ops.items():
+            def dedicated_module_check(n):
+                module_name = self.ops[distiller.normalize_module_name(n)]['module-name']
+                return len(self.module_ops_map[module_name]) == 1 or not dedicated_modules_only
+
+            if not dedicated_module_check(op_name):
+                continue
+
+            entry = AdjacentsEntry()
+            # Find the immediate preceding and succeeding modules. Depth of 1 gets us the
+            # input and output tensors, depth of 2 gets the actual modules
+            entry.predecessors = [n for n in self.predecessors(op, 2) if dedicated_module_check(n)]
+            entry.successors = [n for n in self.successors(op, 2) if dedicated_module_check(n)]
+
+            adj_map[distiller.denormalize_module_name(self._src_model, op_name)] = entry
+
+        return adj_map
+
+
+class AdjacentsEntry(object):
+    def __init__(self):
+        self.predecessors = []
+        self.successors = []
+
+    def __repr__(self):
+        return 'Predecessors: {0} ; Successors: {1}'.format(self.predecessors, self.successors)
diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py
index 55ce88b..c92fe4f 100755
--- a/tests/test_summarygraph.py
+++ b/tests/test_summarygraph.py
@@ -49,7 +49,7 @@ def test_connectivity():
     assert g is not None
 
     op_names = [op['name'] for op in g.ops.values()]
-    assert 81 == len(op_names)
+    assert len(op_names) == 80
 
     edges = g.edges
     assert edges[0].src == '0' and edges[0].dst == 'conv1'
@@ -168,10 +168,9 @@ def test_named_params_layers():
 
 
 def test_onnx_name_2_pytorch_name():
-    assert "layer3.0.relu1" == onnx_name_2_pytorch_name("ResNet/Sequential[layer3]/BasicBlock[0]/ReLU[relu].1", 'Relu')
-    assert "features.34" == onnx_name_2_pytorch_name('VGG/[features]/Sequential/Conv2d[34]', 'Conv')
-    assert "Relu3" == onnx_name_2_pytorch_name('NameWithNoModule.3', 'Relu')
-    #assert "features.module.34" == onnx_name_2_pytorch_name('VGG/DataParallel[features]/Sequential/Conv2d[34]', 'Conv')
+    assert onnx_name_2_pytorch_name("ResNet/Sequential[layer3]/BasicBlock[0]/ReLU[relu]") == "layer3.0.relu"
+    assert onnx_name_2_pytorch_name('VGG/[features]/Sequential/Conv2d[34]') == "features.34"
+    assert onnx_name_2_pytorch_name('NameWithNoModule') == ''
 
 
 def test_connectivity_summary():
@@ -179,10 +178,10 @@ def test_connectivity_summary():
     assert g is not None
 
     summary = connectivity_summary(g)
-    assert len(summary) == 81
+    assert len(summary) == 80
 
     verbose_summary = connectivity_summary_verbose(g)
-    assert len(verbose_summary) == 81
+    assert len(verbose_summary) == 80
 
 
 def test_sg_macs():
-- 
GitLab