diff --git a/hpvm/projects/onnx/frontend/graph_builder.py b/hpvm/projects/onnx/frontend/graph_builder.py index db7cf208e4516ca2d3f221cb6effa31c5641cde6..f7260f2b2c1edfc1f88271e99386a1db6d20319d 100644 --- a/hpvm/projects/onnx/frontend/graph_builder.py +++ b/hpvm/projects/onnx/frontend/graph_builder.py @@ -1,15 +1,23 @@ +from collections import defaultdict from os import PathLike from pathlib import Path +from typing import Dict, List, Optional, Tuple import networkx as nx +import onnx + import graph_ir as g -from tensor import InputTensor, WeightTensor -from collections import defaultdict +from tensor import InputTensor, Tensor, WeightTensor +ModelT = onnx.ModelProto +GraphT = onnx.GraphProto +NodeT = onnx.NodeProto -class GraphBuilder(object): - def __init__(self, model, shape): + +class GraphBuilder: + def __init__(self, model: ModelT, shape: List[int] = None): self._check_model(model) + # TODO: what type is self.shape? pick one. self.shape = shape if shape else self._infer_shape(model.graph) self.tensors = self._extract_tensors_from_graph(model.graph) self.dfg = DFG(model.graph, self.tensors) @@ -19,8 +27,9 @@ class GraphBuilder(object): ################################################ @staticmethod - def _check_model(onnx_model): + def _check_model(onnx_model: ModelT): import warnings + from onnx import checker, onnx_cpp2py_export if hasattr(checker, "check_model"): @@ -31,7 +40,7 @@ class GraphBuilder(object): warnings.warn(str(e)) @staticmethod - def _infer_shape(onnx_graph): + def _infer_shape(onnx_graph: GraphT) -> Dict[str, List[int]]: shape = {} for input in onnx_graph.input: # get type of input tensor @@ -42,7 +51,7 @@ class GraphBuilder(object): return shape @staticmethod - def _extract_tensors_from_graph(onnx_graph): + def _extract_tensors_from_graph(onnx_graph: GraphT) -> Dict[str, Tensor]: tensors = {} # parse weight weight_cnt = 0 @@ -69,12 +78,7 @@ class GraphBuilder(object): tensors[i] = InputTensor(i) return tensors - 
################################################ - # Top level Graph Building functions - # return the compilation-ready graph - ################################################ - - def dump_weights(self, output_dir: PathLike): + def dump_weights(self, output_dir: PathLike) -> None: output_dir = Path(output_dir) for tensor in self.tensors.values(): if not isinstance(tensor, WeightTensor): @@ -83,22 +87,26 @@ class GraphBuilder(object): class DFG(object): - def __init__(self, graph, tensors): + def __init__(self, graph: GraphT, tensors: Dict[str, Tensor]): if len(graph.output) > 1: raise ValueError("Only single-output graph is supported") - self.inputs = graph.input - self.output = graph.output[0] + self.inputs: List[str] = graph.input + self.output: str = graph.output[0] self._onnx_defs, self._onnx_uses = self.def_use(graph.node) self._var_count = 0 self.tensors = tensors self.graph = self.build_dfg(graph) @property - def traverse_order(self): + def traverse_order(self) -> List[g.DFGNode]: return list(nx.topological_sort(self.graph)) @staticmethod - def def_use(nodes): + def def_use(nodes: list) -> Tuple[dict, dict]: + """Computes def/use relation from a list of nodes. + + This method is duck-typed and operates on any node defining .input and .output. 
+ """ defs, uses = {}, defaultdict(list) for n in nodes: for input_ in n.input: @@ -107,19 +115,21 @@ class DFG(object): defs[output] = n return defs, uses - def _allocate_insert_var(self, node1, node2, input_pos: int = 0): + def _allocate_insert_var( + self, node1: g.DFGNode, node2: g.DFGNode, input_pos: int = 0 + ) -> None: varname = f"conv_{self._var_count}" node1.output = [varname] node2.input[input_pos] = varname self._var_count += 1 - def detect_flatten(self, graph): + def detect_flatten(self, graph: GraphT) -> Tuple[Dict[str, NodeT], List[g.DFGNode]]: # Look for a shape-gather-unsqueeze-concat chain nodes = graph.node included_nodes = {} # Name to node generated_nodes = [] - def get_next_in_chain(type_: str, node) -> str: + def get_next_in_chain(type_: str, node: Optional[NodeT]) -> Optional[NodeT]: """ Get a unique user node of the unique output of Node `node`, and return it if it has Type `type_`. @@ -164,7 +174,7 @@ class DFG(object): generated_nodes.append(g.FlattenNode.from_onnx_idiom(nodes)) return included_nodes, generated_nodes - def build_dfg(self, graph) -> nx.DiGraph: + def build_dfg(self, graph: GraphT) -> nx.DiGraph: error_nodes, generated_nodes = [], [] used_onnx_nodes, flatten_nodes = self.detect_flatten(graph) generated_nodes.extend(flatten_nodes) @@ -182,19 +192,19 @@ class DFG(object): raise ValueError(f"Unsupported operators (first 10): {error_repr[:10]}") else: raise ValueError(f"Unsupported operators: {error_repr}") - graph = nx.DiGraph() + ret_graph = nx.DiGraph() defs, uses = self.def_use(generated_nodes) - graph.add_nodes_from(generated_nodes) + ret_graph.add_nodes_from(generated_nodes) for onnx_value_name, use_nodes in uses.items(): if onnx_value_name not in defs: continue def_node = defs[onnx_value_name] for use_node in use_nodes: - graph.add_edge(def_node, use_node) - return graph + ret_graph.add_edge(def_node, use_node) + return ret_graph # This should be the place where partial evaluation happens - def emit_node(self, onnx_node): 
+ def emit_node(self, onnx_node: NodeT) -> Optional[List[g.DFGNode]]: if onnx_node.op_type == "Conv": weight_tensor = self.tensors[onnx_node.input[1]] assert isinstance(weight_tensor, WeightTensor) diff --git a/hpvm/projects/onnx/frontend/hpvm_codegen.py b/hpvm/projects/onnx/frontend/hpvm_codegen.py index b269a7debffb7c3ae4adefb8e4e0caa329a631db..719e51dc1f102c3b134aca3b90709b5a98df9679 100644 --- a/hpvm/projects/onnx/frontend/hpvm_codegen.py +++ b/hpvm/projects/onnx/frontend/hpvm_codegen.py @@ -1,4 +1,6 @@ from os import PathLike +from pathlib import Path +from typing import Dict, List, Tuple, Union import jinja2 @@ -12,22 +14,25 @@ template = template_env.get_template(TEMPLATE_FILE) class HpvmCodeGen: - def __init__(self, DFG: DFG, output_dir: PathLike): - self.dfg = DFG - self.tensors = DFG.tensors + def __init__(self, dfg: DFG, output_dir: PathLike): + self.dfg = dfg + self.tensors = dfg.tensors self.var_count = 0 - self.output_dir = output_dir + self.output_dir = Path(output_dir) # self.variables is a "onnx name to our name" map # Each value is (varname, bool) and the bool indicates # "is root node input" or not. - self.variables = self._get_root_args(DFG.inputs, DFG.tensors) + IdenT = Union[str, int] + self.variables: Dict[str, Tuple[IdenT, bool]] = self._get_root_args( + dfg.inputs, dfg.tensors + ) ################################################ # Aux functions ################################################ @staticmethod - def _get_root_args(input_nodes, tensors): + def _get_root_args(input_nodes, tensors) -> Dict[str, Tuple[int, bool]]: # Input to the graph + all weight tensors # Sometimes these 2 kinds can overlap (due to ONNX optim) # We'll dedup this array as well. 
@@ -40,31 +45,16 @@ class HpvmCodeGen: root_args = sorted(list(set(root_args))) return {f_name: (index, True) for index, f_name in enumerate(root_args)} - def _allocate_varname(self): + def _allocate_varname(self) -> str: varname = f"var_{self.var_count}" self.var_count += 1 return varname - def get_varname_of(self, onnx_var_name): - if onnx_var_name in self.root_args: - return True, self.root_args[onnx_var_name] - elif onnx_var_name in self.local_vars: - return False, self.local_vars[onnx_var_name] - else: - raise KeyError(onnx_var_name) - - @staticmethod - def transform_name(name: str): - name = name.replace(".", "_") - if name[0].isnumeric(): - name = "_" + name - return name - ################################################ # CodeGen functions ################################################ - def _emit_hpvm_node_edges(self, input_vars): + def _emit_hpvm_node_edges(self, input_vars: List[str]) -> List[dict]: ret = [] it = 0 for onnx_var_name in input_vars: @@ -81,7 +71,7 @@ class HpvmCodeGen: it += 1 return ret - def emit_hpvm_node_structures(self): + def emit_hpvm_node_structures(self) -> List[dict]: node_envs = [] for node in self.dfg.traverse_order: generated_code = node.hpvm_codegen(self.tensors) @@ -103,26 +93,26 @@ class HpvmCodeGen: ) return node_envs - def emit_root_io(self): + def emit_root_io(self) -> Tuple[List[str], str]: input_args = [ - self.transform_name(name) + make_c_identifier(name) for name, (_, is_root) in self.variables.items() if is_root ] output_arg = self.variables[self.dfg.output.name][0] return input_args, output_arg - def emit_weights(self): + def emit_weights(self) -> List[dict]: ret = [] for name, tensor in self.tensors.items(): if not isinstance(tensor, WeightTensor): continue - name = self.transform_name(name) + name = make_c_identifier(name) file_path = f"{tensor.get_mapped_name()}_path.bin" ret.append({"name": name, "shape": tensor.shape, "filename": file_path}) return ret - def compile(self): + def compile(self) -> None: 
nodes = self.emit_hpvm_node_structures() inputs, output = self.emit_root_io() weights = self.emit_weights() @@ -136,3 +126,10 @@ class HpvmCodeGen: output_dir=self.output_dir, ) ) + + +def make_c_identifier(name: str) -> str: + name = name.replace(".", "_") + if name[0].isnumeric(): + name = "_" + name + return name diff --git a/hpvm/projects/onnx/frontend/main.py b/hpvm/projects/onnx/frontend/main.py index 7e3809d0eb6c635709b8727ac5a07738db3373c2..161b6e8efcac0128142e2475a67fdc81340d05e4 100644 --- a/hpvm/projects/onnx/frontend/main.py +++ b/hpvm/projects/onnx/frontend/main.py @@ -1,5 +1,6 @@ from pathlib import Path -from typing import Iterable, Optional +from typing import List, Optional + import onnx @@ -24,7 +25,7 @@ def check_version(model, new_version): def compile( model, - input_size: Iterable[int], + input_size: Optional[List[int]], output_dir: Path, opset_version: Optional[int], hpvmc: bool, @@ -35,7 +36,7 @@ def compile( if opset_version is not None: model = check_version(model, opset_version) - graphBuilder = GraphBuilder(model, output_dir) + graphBuilder = GraphBuilder(model) if hpvmc: hpvmCodeGen = HpvmCodeGen(graphBuilder.dfg, output_dir) hpvmCodeGen.compile() @@ -54,7 +55,6 @@ def parse_args(): "-s", "--input-size", type=int, - required=True, nargs="+", help="""Size of input tensor to the model. Usually 4 dim, including batch size.