From f1f0d7531cba023034490f789844d2c490182537 Mon Sep 17 00:00:00 2001 From: Guy Jacob <guy.jacob@intel.com> Date: Wed, 15 May 2019 10:20:53 +0300 Subject: [PATCH] SummaryGraph changes: _force_outplace + OrderedDicts * Set _force_outplace when calling get_trace_graph. This is a workaround for losing scope information for certain in-place operations * Switch all dicts to OrderedDicts --- distiller/summary_graph.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/distiller/summary_graph.py b/distiller/summary_graph.py index 4469294..d3facd5 100755 --- a/distiller/summary_graph.py +++ b/distiller/summary_graph.py @@ -21,6 +21,7 @@ import collections import torch import torch.jit as jit import logging +from collections import OrderedDict msglogger = logging.getLogger() @@ -99,17 +100,17 @@ class SummaryGraph(object): device = next(model_clone.parameters()).device dummy_input = distiller.convert_tensors_recursively_to(dummy_input, device=device) - trace, _ = jit.get_trace_graph(model_clone, dummy_input) + trace, _ = jit.get_trace_graph(model_clone, dummy_input, _force_outplace=True) # Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes # composing a GEMM operation; etc. torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX) graph = trace.graph() - self.ops = {} - self.params = {} + self.ops = OrderedDict() + self.params = OrderedDict() self.edges = [] - self.temp = {} + self.temp = OrderedDict() in_out = list(graph.inputs()) + list(graph.outputs()) for param in in_out: @@ -148,7 +149,7 @@ class SummaryGraph(object): self.__add_output(new_op, output) self.edges.append(SummaryGraph.Edge(new_op['name'], output.uniqueName())) - new_op['attrs'] = {attr_name: node[attr_name] for attr_name in node.attributeNames()} + new_op['attrs'] = OrderedDict([(attr_name, node[attr_name]) for attr_name in node.attributeNames()]) self.add_macs_attr() self.add_footprint_attr() @@ -156,7 +157,7 @@ class SummaryGraph(object): del model_clone def __create_op(self, onnx_node): - op = {} + op = OrderedDict() op['name'] = onnx_node.scopeName() op['orig-name'] = onnx_node.scopeName() op['type'] = onnx_node.kind().lstrip('::onnx') @@ -188,7 +189,7 @@ class SummaryGraph(object): return param def __tensor_desc(self, n): - tensor = {} + tensor = OrderedDict() tensor['id'] = n.uniqueName() try: # try parsing the FM tensor type. For example: Float(1, 64, 8, 8) -- GitLab