From 73b3b3cff33113c53e1aa396c09b041db7e79e90 Mon Sep 17 00:00:00 2001
From: Neta Zmora <31280975+nzmora@users.noreply.github.com>
Date: Mon, 8 Apr 2019 11:38:20 +0300
Subject: [PATCH] Fix issue #213 (#221)

Dropout layers were not handled properly in SummaryGraph, and
caused the indexing of layer names to change.
The root cause is that ONNX uses the same node name for
Dropout and Linear layers that are processed in sequence.
ONNX nodes can be identified by three components: the ONNX
node name, type, and instance.
In SummaryGraph we ignore the node type when naming a node.
Specifically, in AlexNet the Dropout layers that precede a Linear
layer have the same node name and instance, and are only distinguished
by their type.  SummaryGraph, ignorant of the type, skipped the Dropout
layers and gave SG nodes the wrong name.  Thus 'classifier.0', which is
a Dropout node, became a Linear node.
The fix is to stop ignoring duplicate (node name, instance) pairs,
and to increment the instance of the duplicate node instead.
---
 distiller/summary_graph.py | 18 ++++++++++++++++++
 tests/test_summarygraph.py |  6 +++---
 2 files changed, 21 insertions(+), 3 deletions(-)

diff --git a/distiller/summary_graph.py b/distiller/summary_graph.py
index 073bd27..2d2dab6 100755
--- a/distiller/summary_graph.py
+++ b/distiller/summary_graph.py
@@ -49,6 +49,18 @@ def onnx_name_2_pytorch_name(name, op_type):
     return new_name
 
 
+def increment_instance(node_name):
+    """Increment the instance number of a given node"""
+    try:
+        # This assumes that the last character of node_name is the node instance, and that the instance is a
+        # single digit (0-9).  A name that does not end in a digit gets a new ".0" instance appended instead.
+        base_name = node_name[:-1]
+        suffix = str(int(node_name[-1]) + 1)
+        return base_name + suffix
+    except ValueError:
+        return node_name + ".0"
+
+
 class SummaryGraph(object):
     """We use Pytorch's JIT tracer to run a forward pass and generate a trace graph, which
     is an internal representation of the model.  We then use ONNX to "clean" this
@@ -119,6 +131,12 @@ class SummaryGraph(object):
                 new_op['name'] = onnx_name_2_pytorch_name(new_op['name'], new_op['type'])
                 assert len(new_op['name']) > 0
 
+                if new_op['name'] in self.ops:
+                    # This is a patch.
+                    # ONNX names integrate the node type, while we don't (design bug).
+                    # This means that while parsing the ONNX graph we might find two nodes with the "same" name.
+                    # This patch increments the instance name, but this may break in the future.
+                    new_op['name'] = increment_instance(new_op['name'])
                 self.ops[new_op['name']] = new_op
 
                 for input_ in node.inputs():
diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py
index f9342ef..0d378fa 100755
--- a/tests/test_summarygraph.py
+++ b/tests/test_summarygraph.py
@@ -59,7 +59,7 @@ def test_connectivity():
     assert g is not None
 
     op_names = [op['name'] for op in g.ops.values()]
-    assert 80 == len(op_names)
+    assert 81 == len(op_names)
 
     edges = g.edges
     assert edges[0].src == '0' and edges[0].dst == 'conv1'
@@ -173,10 +173,10 @@ def test_connectivity_summary():
     assert g is not None
 
     summary = connectivity_summary(g)
-    assert len(summary) == 80
+    assert len(summary) == 81
 
     verbose_summary = connectivity_summary_verbose(g)
-    assert len(verbose_summary) == 80
+    assert len(verbose_summary) == 81
 
 
 if __name__ == '__main__':
-- 
GitLab