From f1f0d7531cba023034490f789844d2c490182537 Mon Sep 17 00:00:00 2001
From: Guy Jacob <guy.jacob@intel.com>
Date: Wed, 15 May 2019 10:20:53 +0300
Subject: [PATCH] SummaryGraph changes: _force_outplace + OrderedDicts

* Set _force_outplace=True when calling jit.get_trace_graph. This is a
  workaround for scope information being lost on certain in-place
  operations (see the sketch at the end of this message)
* Switch all internal dicts to OrderedDicts
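
For context, below is a minimal sketch of why _force_outplace matters.
The model, shapes and names are illustrative only; the snippet assumes
the private torch.jit.get_trace_graph API (and its _force_outplace
kwarg) that SummaryGraph already relies on:

    import torch
    import torch.jit as jit
    import torch.nn as nn

    class TinyModel(nn.Module):
        """Module whose forward() uses an in-place op (torch.relu_)."""
        def __init__(self):
            super(TinyModel, self).__init__()
            self.conv = nn.Conv2d(3, 8, kernel_size=3)

        def forward(self, x):
            # In-place op: without the workaround, its node may appear
            # in the traced graph with a missing/partial scope name,
            # which breaks SummaryGraph's scopeName()-based op naming.
            return torch.relu_(self.conv(x))

    dummy_input = torch.randn(1, 3, 32, 32)

    # Plain trace - in-place ops can lose their scope information:
    trace, _ = jit.get_trace_graph(TinyModel(), dummy_input)

    # With _force_outplace=True the tracer records out-of-place
    # equivalents, so scope information is preserved for these nodes:
    trace, _ = jit.get_trace_graph(TinyModel(), dummy_input,
                                   _force_outplace=True)

The OrderedDict switch makes iteration order over the ops/params/attrs
dicts deterministic regardless of Python version (plain dict insertion
order is only guaranteed from Python 3.7 onwards).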
---
 distiller/summary_graph.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/distiller/summary_graph.py b/distiller/summary_graph.py
index 4469294..d3facd5 100755
--- a/distiller/summary_graph.py
+++ b/distiller/summary_graph.py
@@ -21,6 +21,7 @@ import collections
 import torch
 import torch.jit as jit
 import logging
+from collections import OrderedDict
 msglogger = logging.getLogger()
 
 
@@ -99,17 +100,17 @@ class SummaryGraph(object):
             
             device = next(model_clone.parameters()).device
             dummy_input = distiller.convert_tensors_recursively_to(dummy_input, device=device)
-            trace, _ = jit.get_trace_graph(model_clone, dummy_input)
+            trace, _ = jit.get_trace_graph(model_clone, dummy_input, _force_outplace=True)
 
             # Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes
             # composing a GEMM operation; etc.
             torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
 
             graph = trace.graph()
-            self.ops = {}
-            self.params = {}
+            self.ops = OrderedDict()
+            self.params = OrderedDict()
             self.edges = []
-            self.temp = {}
+            self.temp = OrderedDict()
 
             in_out = list(graph.inputs()) + list(graph.outputs())
             for param in in_out:
@@ -148,7 +149,7 @@ class SummaryGraph(object):
                     self.__add_output(new_op, output)
                     self.edges.append(SummaryGraph.Edge(new_op['name'], output.uniqueName()))
 
-                new_op['attrs'] = {attr_name: node[attr_name] for attr_name in node.attributeNames()}
+                new_op['attrs'] = OrderedDict([(attr_name, node[attr_name]) for attr_name in node.attributeNames()])
 
         self.add_macs_attr()
         self.add_footprint_attr()
@@ -156,7 +157,7 @@ class SummaryGraph(object):
         del model_clone
 
     def __create_op(self, onnx_node):
-        op = {}
+        op = OrderedDict()
         op['name'] = onnx_node.scopeName()
         op['orig-name'] = onnx_node.scopeName()
         op['type'] = onnx_node.kind().lstrip('::onnx')
@@ -188,7 +189,7 @@ class SummaryGraph(object):
         return param
 
     def __tensor_desc(self, n):
-        tensor = {}
+        tensor = OrderedDict()
         tensor['id'] = n.uniqueName()
         try:
             # try parsing the FM tensor type.  For example: Float(1, 64, 8, 8)
-- 
GitLab