Skip to content
Snippets Groups Projects
Commit c12a9b26 authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Inferred shape information now available

parent 50cd9c21
No related branches found
No related tags found
No related merge requests found
......@@ -7,7 +7,7 @@ import networkx as nx
import onnx
from . import graph_ir as g
from .onnx_attr import node_attr_to_dict
from .onnx_attr import node_attr_to_dict, node_to_shape
GraphT = onnx.GraphProto
NodeT = onnx.NodeProto
......@@ -109,7 +109,9 @@ class DFG(object):
ret_graph = nx.DiGraph()
onnx_defs, onnx_uses = def_use(graph.node)
tensors = extract_tensors_from_graph(graph)
ret_graph.add_nodes_from(graph.node)
node_shape = node_to_shape(graph)
node_and_attr = [(n, {'shape': shape}) for n, shape in node_shape.items()]
ret_graph.add_nodes_from(node_and_attr)
for onnx_value_name, use_nodes in onnx_uses.items():
def_node = onnx_defs.get(onnx_value_name)
if def_node is None:
......
from typing import Tuple
from typing import Dict, List, Optional, Tuple
import numpy as np
from onnx import AttributeProto, NodeProto, TensorProto
from onnx import AttributeProto, NodeProto, TensorProto, GraphProto, TensorShapeProto
def throw_ctor(ty):
......@@ -81,3 +81,18 @@ def parse_node_attr(onnx_attr: AttributeProto) -> Tuple[str, object]:
def node_attr_to_dict(onnx_node: NodeProto):
return {attr.name: parse_node_attr(attr) for attr in onnx_node.attribute}
def node_to_shape(onnx_graph: GraphProto) -> Dict[NodeProto, Optional[List[int]]]:
def parse_shape(shape: TensorShapeProto) -> List[int]:
return [dim.dim_value for dim in shape.dim]
def unique_output_name(node: NodeProto) -> str:
if len(node.output) != 1:
raise ValueError(f"Node {node} has more than 1 outputs")
return node.output[0]
out_name_to_shape = {
vi.name: parse_shape(vi.type.tensor_type.shape) for vi in onnx_graph.value_info
}
return {n: out_name_to_shape.get(unique_output_name(n)) for n in onnx_graph.node}
......@@ -39,6 +39,7 @@ def compile(
model = onnx.load(onnx_file)
if opset is not None:
model = check_version(model, opset)
model = onnx.shape_inference.infer_shapes(model)
dfg = DFG(model.graph)
if hpvmc:
hpvm_code_gen = HpvmCodeGen(dfg, output_dir, input_size, batch_size, prefix)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment