Skip to content
Snippets Groups Projects
Commit f3dea9d3 authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Code cleaning, type annotation

parent c3ade693
No related branches found
No related tags found
No related merge requests found
from collections import defaultdict
from os import PathLike
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import networkx as nx
import onnx
import graph_ir as g
from tensor import InputTensor, WeightTensor
from collections import defaultdict
from tensor import InputTensor, Tensor, WeightTensor
ModelT = onnx.ModelProto
GraphT = onnx.GraphProto
NodeT = onnx.NodeProto
class GraphBuilder(object):
def __init__(self, model, shape):
class GraphBuilder:
def __init__(self, model: ModelT, shape: List[int] = None):
self._check_model(model)
# TODO: what type is self.shape? pick one.
self.shape = shape if shape else self._infer_shape(model.graph)
self.tensors = self._extract_tensors_from_graph(model.graph)
self.dfg = DFG(model.graph, self.tensors)
......@@ -19,8 +27,9 @@ class GraphBuilder(object):
################################################
@staticmethod
def _check_model(onnx_model):
def _check_model(onnx_model: ModelT):
import warnings
from onnx import checker, onnx_cpp2py_export
if hasattr(checker, "check_model"):
......@@ -31,7 +40,7 @@ class GraphBuilder(object):
warnings.warn(str(e))
@staticmethod
def _infer_shape(onnx_graph):
def _infer_shape(onnx_graph: GraphT) -> Dict[str, List[int]]:
shape = {}
for input in onnx_graph.input:
# get type of input tensor
......@@ -42,7 +51,7 @@ class GraphBuilder(object):
return shape
@staticmethod
def _extract_tensors_from_graph(onnx_graph):
def _extract_tensors_from_graph(onnx_graph: GraphT) -> Dict[str, Tensor]:
tensors = {}
# parse weight
weight_cnt = 0
......@@ -69,12 +78,7 @@ class GraphBuilder(object):
tensors[i] = InputTensor(i)
return tensors
################################################
# Top level Graph Building functions
# return the compilation-ready graph
################################################
def dump_weights(self, output_dir: PathLike):
def dump_weights(self, output_dir: PathLike) -> None:
output_dir = Path(output_dir)
for tensor in self.tensors.values():
if not isinstance(tensor, WeightTensor):
......@@ -83,22 +87,26 @@ class GraphBuilder(object):
class DFG(object):
def __init__(self, graph, tensors):
def __init__(self, graph: GraphT, tensors: Dict[str, Tensor]):
if len(graph.output) > 1:
raise ValueError("Only single-output graph is supported")
self.inputs = graph.input
self.output = graph.output[0]
self.inputs: List[str] = graph.input
self.output: str = graph.output[0]
self._onnx_defs, self._onnx_uses = self.def_use(graph.node)
self._var_count = 0
self.tensors = tensors
self.graph = self.build_dfg(graph)
@property
def traverse_order(self):
def traverse_order(self) -> List[g.DFGNode]:
return list(nx.topological_sort(self.graph))
@staticmethod
def def_use(nodes):
def def_use(nodes: list) -> Tuple[dict, dict]:
"""Computes def/use relation from a list of node.
This method is duck-typed and operates on any node defining .input and .output.
"""
defs, uses = {}, defaultdict(list)
for n in nodes:
for input_ in n.input:
......@@ -107,19 +115,21 @@ class DFG(object):
defs[output] = n
return defs, uses
def _allocate_insert_var(self, node1, node2, input_pos: int = 0):
def _allocate_insert_var(
self, node1: g.DFGNode, node2: g.DFGNode, input_pos: int = 0
) -> None:
varname = f"conv_{self._var_count}"
node1.output = [varname]
node2.input[input_pos] = varname
self._var_count += 1
def detect_flatten(self, graph):
def detect_flatten(self, graph: GraphT) -> Tuple[Dict[str, NodeT], List[g.DFGNode]]:
# Look for a shape-gather-unsqueeze-concat chain
nodes = graph.node
included_nodes = {} # Name to node
generated_nodes = []
def get_next_in_chain(type_: str, node) -> str:
def get_next_in_chain(type_: str, node: Optional[NodeT]) -> Optional[NodeT]:
"""
Get a unique user node of the unique output of Node `node`,
and return it if it has Type `type_`.
......@@ -164,7 +174,7 @@ class DFG(object):
generated_nodes.append(g.FlattenNode.from_onnx_idiom(nodes))
return included_nodes, generated_nodes
def build_dfg(self, graph) -> nx.DiGraph:
def build_dfg(self, graph: GraphT) -> nx.DiGraph:
error_nodes, generated_nodes = [], []
used_onnx_nodes, flatten_nodes = self.detect_flatten(graph)
generated_nodes.extend(flatten_nodes)
......@@ -182,19 +192,19 @@ class DFG(object):
raise ValueError(f"Unsupported operators (first 10): {error_repr[:10]}")
else:
raise ValueError(f"Unsupported operators: {error_repr}")
graph = nx.DiGraph()
ret_graph = nx.DiGraph()
defs, uses = self.def_use(generated_nodes)
graph.add_nodes_from(generated_nodes)
ret_graph.add_nodes_from(generated_nodes)
for onnx_value_name, use_nodes in uses.items():
if onnx_value_name not in defs:
continue
def_node = defs[onnx_value_name]
for use_node in use_nodes:
graph.add_edge(def_node, use_node)
return graph
ret_graph.add_edge(def_node, use_node)
return ret_graph
# This should be the place where partial evaluation happens
def emit_node(self, onnx_node):
def emit_node(self, onnx_node: NodeT) -> Optional[List[g.DFGNode]]:
if onnx_node.op_type == "Conv":
weight_tensor = self.tensors[onnx_node.input[1]]
assert isinstance(weight_tensor, WeightTensor)
......
from os import PathLike
from pathlib import Path
from typing import Dict, List, Tuple, Union
import jinja2
......@@ -12,22 +14,25 @@ template = template_env.get_template(TEMPLATE_FILE)
class HpvmCodeGen:
def __init__(self, DFG: DFG, output_dir: PathLike):
self.dfg = DFG
self.tensors = DFG.tensors
def __init__(self, dfg: DFG, output_dir: PathLike):
self.dfg = dfg
self.tensors = dfg.tensors
self.var_count = 0
self.output_dir = output_dir
self.output_dir = Path(output_dir)
# self.variables is a "onnx name to our name" map
# Each value is (varname, bool) and the bool indicates
# "is root node input" or not.
self.variables = self._get_root_args(DFG.inputs, DFG.tensors)
IdenT = Union[str, int]
self.variables: Dict[str, Tuple[IdenT, bool]] = self._get_root_args(
dfg.inputs, dfg.tensors
)
################################################
# Aux functions
################################################
@staticmethod
def _get_root_args(input_nodes, tensors):
def _get_root_args(input_nodes, tensors) -> Dict[str, Tuple[int, bool]]:
# Input to the graph + all weight tensors
# Sometimes these 2 kinds can overlap (due to ONNX optim)
# We'll dedup this array as well.
......@@ -40,31 +45,16 @@ class HpvmCodeGen:
root_args = sorted(list(set(root_args)))
return {f_name: (index, True) for index, f_name in enumerate(root_args)}
def _allocate_varname(self):
def _allocate_varname(self) -> str:
varname = f"var_{self.var_count}"
self.var_count += 1
return varname
def get_varname_of(self, onnx_var_name):
if onnx_var_name in self.root_args:
return True, self.root_args[onnx_var_name]
elif onnx_var_name in self.local_vars:
return False, self.local_vars[onnx_var_name]
else:
raise KeyError(onnx_var_name)
@staticmethod
def transform_name(name: str):
name = name.replace(".", "_")
if name[0].isnumeric():
name = "_" + name
return name
################################################
# CodeGen functions
################################################
def _emit_hpvm_node_edges(self, input_vars):
def _emit_hpvm_node_edges(self, input_vars: List[str]) -> List[dict]:
ret = []
it = 0
for onnx_var_name in input_vars:
......@@ -81,7 +71,7 @@ class HpvmCodeGen:
it += 1
return ret
def emit_hpvm_node_structures(self):
def emit_hpvm_node_structures(self) -> List[dict]:
node_envs = []
for node in self.dfg.traverse_order:
generated_code = node.hpvm_codegen(self.tensors)
......@@ -103,26 +93,26 @@ class HpvmCodeGen:
)
return node_envs
def emit_root_io(self):
def emit_root_io(self) -> Tuple[List[str], str]:
input_args = [
self.transform_name(name)
make_c_identifier(name)
for name, (_, is_root) in self.variables.items()
if is_root
]
output_arg = self.variables[self.dfg.output.name][0]
return input_args, output_arg
def emit_weights(self):
def emit_weights(self) -> List[dict]:
ret = []
for name, tensor in self.tensors.items():
if not isinstance(tensor, WeightTensor):
continue
name = self.transform_name(name)
name = make_c_identifier(name)
file_path = f"{tensor.get_mapped_name()}_path.bin"
ret.append({"name": name, "shape": tensor.shape, "filename": file_path})
return ret
def compile(self):
def compile(self) -> None:
nodes = self.emit_hpvm_node_structures()
inputs, output = self.emit_root_io()
weights = self.emit_weights()
......@@ -136,3 +126,10 @@ class HpvmCodeGen:
output_dir=self.output_dir,
)
)
def make_c_identifier(name: str) -> str:
name = name.replace(".", "_")
if name[0].isnumeric():
name = "_" + name
return name
from pathlib import Path
from typing import Iterable, Optional
from typing import List, Optional
import onnx
......@@ -24,7 +25,7 @@ def check_version(model, new_version):
def compile(
model,
input_size: Iterable[int],
input_size: Optional[List[int]],
output_dir: Path,
opset_version: Optional[int],
hpvmc: bool,
......@@ -35,7 +36,7 @@ def compile(
if opset_version is not None:
model = check_version(model, opset_version)
graphBuilder = GraphBuilder(model, output_dir)
graphBuilder = GraphBuilder(model)
if hpvmc:
hpvmCodeGen = HpvmCodeGen(graphBuilder.dfg, output_dir)
hpvmCodeGen.compile()
......@@ -54,7 +55,6 @@ def parse_args():
"-s",
"--input-size",
type=int,
required=True,
nargs="+",
help="""Size of input tensor to the model.
Usually 4 dim, including batch size.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment