diff --git a/distiller/summary_graph.py b/distiller/summary_graph.py index 073bd27b680a8bcb170673cbf61ae881053ec5bb..2d2dab63b81f4bd1fe9b05e26bee75bf7d58d019 100755 --- a/distiller/summary_graph.py +++ b/distiller/summary_graph.py @@ -49,6 +49,18 @@ def onnx_name_2_pytorch_name(name, op_type): return new_name +def increment_instance(node_name): + """Increment the instance number of a given node""" + try: + # There is an assumption here that the last character in node_name is the node instance (an integer), + # and that it is between 0-9 (i.e. a digit) + base_name = node_name[:-1] + suffix = str(int(node_name[-1]) + 1) + return base_name + suffix + except ValueError: + return node_name + ".0" + + class SummaryGraph(object): """We use Pytorch's JIT tracer to run a forward pass and generate a trace graph, which is an internal representation of the model. We then use ONNX to "clean" this @@ -119,6 +131,12 @@ class SummaryGraph(object): new_op['name'] = onnx_name_2_pytorch_name(new_op['name'], new_op['type']) assert len(new_op['name']) > 0 + if new_op['name'] in self.ops: + # This is a patch. + # ONNX names integrate the node type, while we don't (design bug). + # This means that while parsing the ONNX graph we might find two nodes with the "same" name. + # This patch increments the instance name, but this may break in the future. + new_op['name'] = increment_instance(new_op['name']) self.ops[new_op['name']] = new_op for input_ in node.inputs(): diff --git a/tests/test_summarygraph.py b/tests/test_summarygraph.py index f9342ef31495a2a3397b0871b5c82967d34908cb..0d378fa9f18df827cbde0af82ce785f6b6d71d7d 100755 --- a/tests/test_summarygraph.py +++ b/tests/test_summarygraph.py @@ -59,7 +59,7 @@ def test_connectivity(): assert g is not None op_names = [op['name'] for op in g.ops.values()] - assert 80 == len(op_names) + assert 81 == len(op_names) edges = g.edges assert edges[0].src == '0' and edges[0].dst == 'conv1' @@ -173,10 +173,10 @@ def test_connectivity_summary(): assert g is not None summary = connectivity_summary(g) - assert len(summary) == 80 + assert len(summary) == 81 verbose_summary = connectivity_summary_verbose(g) - assert len(verbose_summary) == 80 + assert len(verbose_summary) == 81 if __name__ == '__main__':