From c12a9b2629aca5e9c6a9e60998118f57d17f1280 Mon Sep 17 00:00:00 2001
From: Yifan Zhao <yifanz16@illinois.edu>
Date: Tue, 12 Jan 2021 19:12:32 -0600
Subject: [PATCH] Inferred shape information now available

---
 .../onnx_frontend/frontend/graph_builder.py   |  6 ++++--
 .../onnx_frontend/frontend/onnx_attr.py       | 19 +++++++++++++++++--
 hpvm/projects/onnx_frontend/main.py           |  1 +
 3 files changed, 22 insertions(+), 4 deletions(-)

diff --git a/hpvm/projects/onnx_frontend/frontend/graph_builder.py b/hpvm/projects/onnx_frontend/frontend/graph_builder.py
index 6452508e4a..0d468ac19e 100644
--- a/hpvm/projects/onnx_frontend/frontend/graph_builder.py
+++ b/hpvm/projects/onnx_frontend/frontend/graph_builder.py
@@ -7,7 +7,7 @@ import networkx as nx
 import onnx
 
 from . import graph_ir as g
-from .onnx_attr import node_attr_to_dict
+from .onnx_attr import node_attr_to_dict, node_to_shape
 
 GraphT = onnx.GraphProto
 NodeT = onnx.NodeProto
@@ -109,7 +109,9 @@ class DFG(object):
         ret_graph = nx.DiGraph()
         onnx_defs, onnx_uses = def_use(graph.node)
         tensors = extract_tensors_from_graph(graph)
-        ret_graph.add_nodes_from(graph.node)
+        node_shape = node_to_shape(graph)
+        node_and_attr = [(n, {'shape': shape}) for n, shape in node_shape.items()]
+        ret_graph.add_nodes_from(node_and_attr)
         for onnx_value_name, use_nodes in onnx_uses.items():
             def_node = onnx_defs.get(onnx_value_name)
             if def_node is None:
diff --git a/hpvm/projects/onnx_frontend/frontend/onnx_attr.py b/hpvm/projects/onnx_frontend/frontend/onnx_attr.py
index 7434dccbbb..a43961b024 100644
--- a/hpvm/projects/onnx_frontend/frontend/onnx_attr.py
+++ b/hpvm/projects/onnx_frontend/frontend/onnx_attr.py
@@ -1,7 +1,7 @@
-from typing import Tuple
+from typing import Dict, List, Optional, Tuple
 
 import numpy as np
-from onnx import AttributeProto, NodeProto, TensorProto
+from onnx import AttributeProto, NodeProto, TensorProto, GraphProto, TensorShapeProto
 
 
 def throw_ctor(ty):
@@ -81,3 +81,18 @@ def parse_node_attr(onnx_attr: AttributeProto) -> Tuple[str, object]:
 
 def node_attr_to_dict(onnx_node: NodeProto):
     return {attr.name: parse_node_attr(attr) for attr in onnx_node.attribute}
+
+
+def node_to_shape(onnx_graph: GraphProto) -> Dict[NodeProto, Optional[List[int]]]:
+    def parse_shape(shape: TensorShapeProto) -> List[int]:
+        return [dim.dim_value for dim in shape.dim]
+
+    def unique_output_name(node: NodeProto) -> str:
+        if len(node.output) != 1:
+            raise ValueError(f"Node {node} has more than 1 outputs")
+        return node.output[0]
+
+    out_name_to_shape = {
+        vi.name: parse_shape(vi.type.tensor_type.shape) for vi in onnx_graph.value_info
+    }
+    return {n: out_name_to_shape.get(unique_output_name(n)) for n in onnx_graph.node}
diff --git a/hpvm/projects/onnx_frontend/main.py b/hpvm/projects/onnx_frontend/main.py
index 184877082b..7dc23d4873 100644
--- a/hpvm/projects/onnx_frontend/main.py
+++ b/hpvm/projects/onnx_frontend/main.py
@@ -39,6 +39,7 @@ def compile(
     model = onnx.load(onnx_file)
     if opset is not None:
         model = check_version(model, opset)
+    model = onnx.shape_inference.infer_shapes(model)
     dfg = DFG(model.graph)
     if hpvmc:
         hpvm_code_gen = HpvmCodeGen(dfg, output_dir, input_size, batch_size, prefix)
-- 
GitLab