From d5f71a5bcaffd64a6356f7d93904f275e742f2a3 Mon Sep 17 00:00:00 2001 From: Yifan Zhao <yifanz16@illinois.edu> Date: Sun, 6 Dec 2020 03:32:13 -0600 Subject: [PATCH] Some comments --- hpvm/projects/onnx/frontend/graph_builder.py | 42 ++++++++++++++-- hpvm/projects/onnx/frontend/graph_ir.py | 52 ++++++++++++++------ hpvm/projects/onnx/frontend/main.py | 2 +- 3 files changed, 76 insertions(+), 20 deletions(-) diff --git a/hpvm/projects/onnx/frontend/graph_builder.py b/hpvm/projects/onnx/frontend/graph_builder.py index e6467bc158..3996b947d2 100644 --- a/hpvm/projects/onnx/frontend/graph_builder.py +++ b/hpvm/projects/onnx/frontend/graph_builder.py @@ -1,5 +1,4 @@ -from collections import defaultdict, namedtuple -from onnx_attr import node_attr_to_dict +from collections import defaultdict from os import PathLike from pathlib import Path from typing import Dict, Iterable, List, Optional, Tuple, Union @@ -8,6 +7,7 @@ import networkx as nx import onnx import graph_ir as g +from onnx_attr import node_attr_to_dict GraphT = onnx.GraphProto NodeT = onnx.NodeProto @@ -16,6 +16,15 @@ NodeT.__repr__ = NodeT.__str__ = lambda self: self.name class MarkedSubGraph: + """A subgraph with information on how it should replace a node in a super graph. + + subgraph: a nx.DiGraph subgraph + entry_edges: a list of edges from nodes "outside" to nodes in self.subgraph + exit: the exit node of the subgraph. + When this subgraph replaces a node `n`, self.exit will be connected to + whatever `n` is connected to. 
+ """ + def __init__(self, subgraph: nx.DiGraph, entry_edges, exit) -> None: assert all(to in subgraph for _, to, _ in entry_edges) assert exit in subgraph @@ -24,6 +33,9 @@ class MarkedSubGraph: @classmethod def idiomatic_1to2(cls, node1, node2, predecessors): + """Create an idiomatic replacement as follows: + + node(arg1, arg2, arg3) -> node2(node1(arg1, arg2), arg3)""" p0, p1, p2 = predecessors graph = nx.DiGraph() graph.add_edge(node1, node2, index=0) @@ -34,24 +46,36 @@ EmitNodeT = Union[MarkedSubGraph, g.DFGNode] class DFG(object): + """ONNX model translated into DFG with `DFGNode`s. + + This class has a DFG, input/output information, and a clear traverse order + (think dominator tree), and is easier for CodeGen classes to work with.""" + def __init__(self, graph: GraphT): self._check_model(graph) self._var_count = 0 + # Build explicit DFG with ONNX nodes onnx_graph = self._build_onnx_dfg(graph) + # Convert ONNX dfg into DFGNode DFG self.graph = self._build_dfg(onnx_graph) + # Find out input nodes and output node (unique) + # removing dead nodes along the way if any self.inputs, self.output = self._dce_get_io_info() ################ Interfaces: @property def traverse_order(self) -> List[g.DFGNode]: + """Get topological order of computational graph by use-def relation.""" return list(nx.topological_sort(self.graph)) - def node_args(self, node): + def node_args(self, node: g.DFGNode): + """Get input arguments of node.""" sorted_edges = sorted(self.graph.in_edges(node, "index"), key=lambda p: p[2]) return [e[0] for e in sorted_edges] def dump_weights(self, output_dir: PathLike) -> None: + """Dump `WeightTensor`s into output_dir.""" output_dir = Path(output_dir) for node in self.graph.nodes: if not isinstance(node, g.WeightTensor): @@ -62,6 +86,8 @@ class DFG(object): @staticmethod def _check_model(onnx_graph: GraphT): + """Check model validity and single output (which is our limitation)""" + import warnings from onnx import checker, onnx_cpp2py_export @@ -93,7 
+119,15 @@ class DFG(object): return ret_graph def _build_dfg(self, onnx_graph: nx.DiGraph) -> nx.DiGraph: + """Translate _build_onnx_dfg output into DFGNode DFG. + + First run some passes to process subgraphs that need to be + processed together, then each unprocessed node is generated into + 1 or more nodes.""" + + # Remove subgraphs that can be a single Flatten instead onnx_graph = detect_flatten(onnx_graph) + # Remove subgraphs that look like padding but do nothing onnx_graph = remove_no_padding(onnx_graph) # For each onnx node, generate our nodes node_to_nodes, error_nodes = {}, [] @@ -109,6 +143,7 @@ class DFG(object): raise ValueError(f"Unsupported operators (first 10): {error_repr[:10]}") else: raise ValueError(f"Unsupported operators: {error_repr}") + # Apply node_to_nodes replacement on onnx_graph to create a new DFG return build_graph_with_mapping(onnx_graph, node_to_nodes) def _dce_get_io_info(self): @@ -187,6 +222,7 @@ def def_use(nodes: Iterable) -> Tuple[dict, dict]: def remove_no_padding(graph: nx.DiGraph) -> nx.DiGraph: + """Remove subgraphs that look like padding but do nothing.""" for node in list(graph.nodes): if node.op_type != "Pad": continue diff --git a/hpvm/projects/onnx/frontend/graph_ir.py b/hpvm/projects/onnx/frontend/graph_ir.py index 9ffff9944c..fd42617f60 100644 --- a/hpvm/projects/onnx/frontend/graph_ir.py +++ b/hpvm/projects/onnx/frontend/graph_ir.py @@ -1,11 +1,6 @@ -################################################ -# Top Level DFGNode interface -################################################ - - import abc from os import PathLike -from typing import List +from typing import List, Tuple import onnx @@ -13,27 +8,32 @@ from onnx_attr import node_attr_to_dict class DFGNode(abc.ABC): + """Abstract node that represents 1 instruction in HPVM. + + op_type should be overridden in subclasses for readability. 
+ """ + op_type = "" def __init__(self, name: str, attrs: dict = {}): self.name, self.attrs = name, attrs - def codegen(self): + def codegen(self) -> Tuple[str, list]: return "", [] - def hpvm_codegen(self): + def hpvm_codegen(self) -> Tuple[str, list]: return "", [] - def __repr__(self): + def __repr__(self) -> str: return f"{self.op_type}({self.name})" -################################################ -# Actual Implementation of Operators -################################################ +class TensorNode(DFGNode, abc.ABC): + """An abstract node for a value that exists without an instruction. + This is akin to Value class in LLVM, but in a different place on the + inheritance tree.""" -class TensorNode(DFGNode): def __init__(self, proto: onnx.TensorProto, new_name: str): if not proto.name.strip(): raise ValueError("Tensor's name is required.") @@ -47,6 +47,12 @@ class TensorNode(DFGNode): class InputTensor(TensorNode): + """Input to the computation graph. + + This is basically only used for its information about the ONNX input; + it doesn't itself emit an instruction or anything else of interest. + """ + op_type = "InputTensor" def __init__(self, input_proto: onnx.TensorProto, new_name: str): @@ -59,6 +65,12 @@ class InputTensor(TensorNode): class WeightTensor(TensorNode): + """An initialized parameter in ONNX graph. + + This is any parameter that has an initializer value in the ONNX model + (as opposed to InputTensor, which doesn't have any value). 
+ """ + op_type = "WeightTensor" def __init__(self, weight_proto: onnx.TensorProto, new_name: str): @@ -113,6 +125,8 @@ class Conv2DNode(DFGNode): class _Pool2DNode(DFGNode, abc.ABC): + """Common super class of Average pooling and Max pooling.""" + pool_type = "0" def __init__(self, name: str, attrs: dict): @@ -170,7 +184,13 @@ class MatMulNode(DFGNode): return "__visc__tensor_mul", [] @staticmethod - def gemm_transpose(onnx_node, predec): + def gemm_transpose(onnx_gemm_node, predec): + """Find and transpose weights of the onnx gemm node. + + This way we transpose the constant weight instead of exporting + a transpose node (which doesn't yet exist in HPVM). + """ + def _transpose(weight): if not isinstance(weight, WeightTensor): raise ValueError( @@ -179,7 +199,7 @@ class MatMulNode(DFGNode): weight.transpose_() # Some tensors may need transposing - attrs = node_attr_to_dict(onnx_node) + attrs = node_attr_to_dict(onnx_gemm_node) if attrs.get("transA", False): _transpose(predec[0]) if attrs.get("transB", False): @@ -257,7 +277,7 @@ class ActivationNode(DFGNode): class LogicalOpNode(DFGNode): """ - ELement wise operators that is not for activation function. + Element-wise operators that are not activation functions. In other words, they are logical comparison operators e.g. And, Equal, Greater, GreaterOrEqual, Less, LessOrEqual, Or, Xor diff --git a/hpvm/projects/onnx/frontend/main.py b/hpvm/projects/onnx/frontend/main.py index 25e76ae8fa..d3037224ac 100644 --- a/hpvm/projects/onnx/frontend/main.py +++ b/hpvm/projects/onnx/frontend/main.py @@ -90,7 +90,7 @@ hpvmc: HPVM C Interface. Default value is hpvmc.""", args = parser.parse_args() args.hpvmc = args.compile_mode == "hpvmc" - delattr(args, 'compile_mode') + delattr(args, "compile_mode") return args -- GitLab