From c12a9b2629aca5e9c6a9e60998118f57d17f1280 Mon Sep 17 00:00:00 2001 From: Yifan Zhao <yifanz16@illinois.edu> Date: Tue, 12 Jan 2021 19:12:32 -0600 Subject: [PATCH] Inferred shape information now available --- .../onnx_frontend/frontend/graph_builder.py | 6 ++++-- .../onnx_frontend/frontend/onnx_attr.py | 19 +++++++++++++++++-- hpvm/projects/onnx_frontend/main.py | 1 + 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/hpvm/projects/onnx_frontend/frontend/graph_builder.py b/hpvm/projects/onnx_frontend/frontend/graph_builder.py index 6452508e4a..0d468ac19e 100644 --- a/hpvm/projects/onnx_frontend/frontend/graph_builder.py +++ b/hpvm/projects/onnx_frontend/frontend/graph_builder.py @@ -7,7 +7,7 @@ import networkx as nx import onnx from . import graph_ir as g -from .onnx_attr import node_attr_to_dict +from .onnx_attr import node_attr_to_dict, node_to_shape GraphT = onnx.GraphProto NodeT = onnx.NodeProto @@ -109,7 +109,9 @@ class DFG(object): ret_graph = nx.DiGraph() onnx_defs, onnx_uses = def_use(graph.node) tensors = extract_tensors_from_graph(graph) - ret_graph.add_nodes_from(graph.node) + node_shape = node_to_shape(graph) + node_and_attr = [(n, {'shape': shape}) for n, shape in node_shape.items()] + ret_graph.add_nodes_from(node_and_attr) for onnx_value_name, use_nodes in onnx_uses.items(): def_node = onnx_defs.get(onnx_value_name) if def_node is None: diff --git a/hpvm/projects/onnx_frontend/frontend/onnx_attr.py b/hpvm/projects/onnx_frontend/frontend/onnx_attr.py index 7434dccbbb..a43961b024 100644 --- a/hpvm/projects/onnx_frontend/frontend/onnx_attr.py +++ b/hpvm/projects/onnx_frontend/frontend/onnx_attr.py @@ -1,7 +1,7 @@ -from typing import Tuple +from typing import Dict, List, Optional, Tuple import numpy as np -from onnx import AttributeProto, NodeProto, TensorProto +from onnx import AttributeProto, NodeProto, TensorProto, GraphProto, TensorShapeProto def throw_ctor(ty): @@ -81,3 +81,18 @@ def parse_node_attr(onnx_attr: AttributeProto) -> Tuple[str, object]: def node_attr_to_dict(onnx_node: NodeProto): return {attr.name: parse_node_attr(attr) for attr in onnx_node.attribute} + + +def node_to_shape(onnx_graph: GraphProto) -> Dict[NodeProto, Optional[List[int]]]: + def parse_shape(shape: TensorShapeProto) -> List[int]: + return [dim.dim_value for dim in shape.dim] + + def unique_output_name(node: NodeProto) -> str: + if len(node.output) != 1: + raise ValueError(f"Node {node} has more than 1 outputs") + return node.output[0] + + out_name_to_shape = { + vi.name: parse_shape(vi.type.tensor_type.shape) for vi in onnx_graph.value_info + } + return {n: out_name_to_shape.get(unique_output_name(n)) for n in onnx_graph.node} diff --git a/hpvm/projects/onnx_frontend/main.py b/hpvm/projects/onnx_frontend/main.py index 184877082b..7dc23d4873 100644 --- a/hpvm/projects/onnx_frontend/main.py +++ b/hpvm/projects/onnx_frontend/main.py @@ -39,6 +39,7 @@ def compile( model = onnx.load(onnx_file) if opset is not None: model = check_version(model, opset) + model = onnx.shape_inference.infer_shapes(model) dfg = DFG(model.graph) if hpvmc: hpvm_code_gen = HpvmCodeGen(dfg, output_dir, input_size, batch_size, prefix) -- GitLab