from collections import defaultdict
from os import PathLike
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple, Union

import networkx as nx
import onnx

from . import graph_ir as g
from .onnx_attr import node_attr_to_dict

GraphT = onnx.GraphProto
NodeT = onnx.NodeProto
# Hash ONNX nodes by object identity so they can be used as networkx graph
# nodes, and print them by their name.
NodeT.__hash__ = lambda self: id(self)
NodeT.__repr__ = NodeT.__str__ = lambda self: self.name


class MarkedSubGraph:
    """A subgraph with information on how it should replace a node in a super graph.

    subgraph: a nx.DiGraph subgraph
    entry_edges: a list of edges from nodes "outside" to nodes in self.subgraph,
        given as (from, to, argument_index) triples
    exit: the exit node of the subgraph.
        When this subgraph replaces a node `n`, self.exit will be connected to
        whatever `n` is connected to.
    """

    def __init__(self, subgraph: nx.DiGraph, entry_edges, exit) -> None:
        assert all(to in subgraph for _, to, _ in entry_edges)
        assert exit in subgraph
        self.subgraph, self.exit = subgraph, exit
        # Stored in networkx (from, to, data) form so they can be passed
        # directly to DiGraph.add_edges_from.
        self.entry_edges = [(f, t, {"index": i}) for f, t, i in entry_edges]

    @classmethod
    def idiomatic_1to2(cls, node1, node2, predecessors):
        """Create an idiomatic replacement as follows:

        node(arg1, arg2, arg3) -> node2(node1(arg1, arg2), arg3)
        """
        p0, p1, p2 = predecessors
        graph = nx.DiGraph()
        graph.add_edge(node1, node2, index=0)
        return cls(graph, [(p0, node1, 0), (p1, node1, 1), (p2, node2, 1)], node2)


EmitNodeT = Union[MarkedSubGraph, g.DFGNode]


class DFG(object):
    """ONNX model translated into DFG with `DFGNode`s.

    This class has a DFG, input/output information, and a clear traverse order
    (think dominant tree), and is easier for CodeGen classes to work with."""

    def __init__(self, graph: GraphT):
        self._check_model(graph)
        self._var_count = 0
        # Build explicit DFG with ONNX nodes
        onnx_graph = self._build_onnx_dfg(graph)
        # Convert ONNX dfg into DFGNode DFG
        self.graph = self._build_dfg(onnx_graph)
        # Find out input nodes and output node (unique)
        # removing dead nodes along the way if any
        self.inputs, self.output = self._dce_get_io_info()

    ################ Interfaces:

    @property
    def traverse_order(self) -> List[g.DFGNode]:
        """Get topological order of computational graph by use-def relation."""
        return list(nx.topological_sort(self.graph))

    def node_args(self, node: g.DFGNode):
        """Get input arguments of node, ordered by argument position."""
        return sorted_inputs(self.graph, node)

    def dump_weights(self, output_dir: PathLike) -> None:
        """Dump `WeightTensor`s into output_dir.

        Each weight is written to `<output_dir>/<node.new_name>_path.bin`."""
        output_dir = Path(output_dir)
        for node in self.graph.nodes:
            if not isinstance(node, g.WeightTensor):
                continue
            node.dump_weight(output_dir / (node.new_name + "_path.bin"))

    ################ Internal methods (high-level):

    @staticmethod
    def _check_model(onnx_graph: GraphT):
        """Check model validity and single output (which is our limitation).

        A failure of onnx's own checker only produces a warning.

        Raises:
            ValueError: if any node, or the graph itself, has more than one output.
        """
        import warnings
        from onnx import checker, onnx_cpp2py_export

        # try use onnx's own model checker before converting any model
        try:
            checker.check_graph(onnx_graph)
        except onnx_cpp2py_export.checker.ValidationError as e:
            warnings.warn(str(e))
        if any(len(n.output) > 1 for n in onnx_graph.node):
            raise ValueError("All node must have single output")
        if len(onnx_graph.output) > 1:
            raise ValueError("Graph must have single output")

    @staticmethod
    def _build_onnx_dfg(graph: GraphT) -> nx.DiGraph:
        """Creates a DiGraph (by use-def relation) of onnx nodes from onnx GraphProto.

        DiGraph is easier to use as a graph compared to GraphProto where use-def
        is implicit. Values not produced by any node are looked up among the
        graph's initializers/inputs (see `extract_tensors_from_graph`). Each edge
        carries an `index` attribute: the argument position at which the target
        node uses the source's output.

        NOTE: a DiGraph holds at most one edge per (def, use) pair, so a node
        using the same value at two argument positions keeps only the last one.
        """
        ret_graph = nx.DiGraph()
        onnx_defs, onnx_uses = def_use(graph.node)
        tensors = extract_tensors_from_graph(graph)
        ret_graph.add_nodes_from(graph.node)
        for onnx_value_name, use_nodes in onnx_uses.items():
            if not onnx_value_name:
                # ONNX uses "" for an omitted optional input: there is no value
                # to connect (and no tensor of that name to look up).
                continue
            def_node = onnx_defs.get(onnx_value_name)
            if def_node is None:
                def_node = tensors[onnx_value_name]
            for use_node, used_at_narg in use_nodes:
                ret_graph.add_edge(def_node, use_node, index=used_at_narg)
        return ret_graph

    def _build_dfg(self, onnx_graph: nx.DiGraph) -> nx.DiGraph:
        """Translate _build_onnx_dfg output into DFGNode DFG.

        First run some passes to process subgraphs that needs to be processed
        together, then each unprocessed node is generated into 1 or more nodes.

        Raises:
            ValueError: listing (up to 10 of) the nodes no translation exists for.
        """
        # Remove subgraphs that can be a single Flatten instead
        onnx_graph = detect_flatten(onnx_graph)
        # Remove subgraphs that look like padding but does nothing
        onnx_graph = remove_no_padding(onnx_graph)
        # For each onnx node, generate our nodes
        node_to_nodes, error_nodes = {}, []
        for onnx_node in nx.topological_sort(onnx_graph):
            our_nodes = self._emit_node(onnx_graph, onnx_node)
            if our_nodes is None:
                error_nodes.append(onnx_node)
            else:
                node_to_nodes[onnx_node] = our_nodes
        if error_nodes:
            error_repr = [f"{n.name}({n.op_type})" for n in error_nodes]
            if len(error_nodes) > 10:  # Magic number
                raise ValueError(f"Unsupported operators (first 10): {error_repr[:10]}")
            else:
                raise ValueError(f"Unsupported operators: {error_repr}")
        # Apply node_to_nodes replacement on onnx_graph to create a new DFG
        return build_graph_with_mapping(onnx_graph, node_to_nodes)

    def _dce_get_io_info(self):
        """Remove dead nodes from self.graph and return (inputs, output).

        A node is kept if its weakly-connected component contains at least one
        `g.InputTensor`; all other nodes are removed from self.graph in place.
        The output is the unique remaining node with no successors.
        """
        inputs = [n for n in self.graph if isinstance(n, g.InputTensor)]
        inputs_set = set(inputs)
        reachables = set()
        # Weak components of a DiGraph are the components of its undirected
        # version; no undirected copy needs to be built.
        for component in nx.weakly_connected_components(self.graph):
            # If any inputs goes into this subgraph, it's alive.
            if not component.isdisjoint(inputs_set):
                reachables.update(component)
        unreachables = set(self.graph) - reachables
        # Remove nodes not connected to any input
        self.graph.remove_nodes_from(unreachables)
        # Then outputs are nodes with out_degree = 0
        outputs = [n for n in self.graph if self.graph.out_degree[n] == 0]
        assert len(outputs) == 1, f"Expected exactly 1 output node, found {outputs}"
        return inputs, outputs[0]

    @staticmethod
    def _emit_node(in_graph: nx.DiGraph, node: NodeT) -> Optional[EmitNodeT]:
        """Translate one node of the onnx DFG into our node(s).

        Returns the `g.DFGNode` as-is if `node` already is one, a single new
        node for one-to-one ops, a `MarkedSubGraph` for ops with a bias that is
        split into a separate BiasAdd, or None if the op is unsupported.
        """
        if isinstance(node, g.DFGNode):
            # Directly add node into return graph.
            return node
        predec = sorted_inputs(in_graph, node)
        attrs = node_attr_to_dict(node)

        if node.op_type == "Conv":
            weight_tensor = predec[1]
            assert isinstance(weight_tensor, g.WeightTensor)
            if len(weight_tensor.shape) != 4:
                return None  # Only supports 2D conv
            conv_node = g.Conv2DNode(node.name, attrs)
            if len(predec) == 2:
                return conv_node
            # Split into conv followed by an addition
            bias_node = g.BiasAddNode(f"Bias_{node.name.split('_')[-1]}")
            return MarkedSubGraph.idiomatic_1to2(conv_node, bias_node, predec)
        elif node.op_type in ("MatMul", "Gemm"):
            mul_node = g.MatMulNode(node.name, attrs)
            if node.op_type == "Gemm":
                mul_node.gemm_transpose(node, predec)
            if len(predec) == 2:
                return mul_node
            # Split into mul followed by an addition
            bias_node = g.BiasAddNode(f"Bias_{node.name.split('_')[-1]}")
            return MarkedSubGraph.idiomatic_1to2(mul_node, bias_node, predec)

        one_to_one_nodes = {
            "MaxPool": g.MaxPool2DNode,
            "AveragePool": g.AveragePool2DNode,
            "Add": g.AddNode,
            "Softmax": g.SoftMaxNode,
            "Relu": g.ReluNode,
            "Tanh": g.TanhNode,
            "BatchNormalization": g.BatchNormalizationNode,
            "Pad": g.PadNode,
            "Identity": g.IdentityNode,
            "Flatten": g.FlattenNode,
        }
        node_cls = one_to_one_nodes.get(node.op_type)
        if node_cls is None:
            return None
        return node_cls(node.name, attrs)


def def_use(nodes: Iterable) -> Tuple[dict, dict]:
    """Computes def/use relation from a list of node.

    This method is duck-typed and operates on any node defining .input and .output.

    Returns:
        (defs, uses): `defs` maps a value name to the node producing it;
        `uses` maps a value name to a list of (node, argument_index) pairs.
    """
    defs, uses = {}, defaultdict(list)
    for n in nodes:
        for i, input_ in enumerate(n.input):
            uses[input_].append((n, i))
        for output in n.output:
            defs[output] = n
    return defs, uses


def remove_no_padding(graph: nx.DiGraph) -> nx.DiGraph:
    """Remove subgraphs that look like padding but does nothing.

    A `Pad` whose paddings (second input) come from a `Constant` node holding
    all zeros is bypassed: its first input is wired directly to the Pad's
    users, the Pad is removed, and so is the Constant unless something else
    still uses it. Pads whose paddings cannot be inspected this way are kept.
    Modifies `graph` in place and returns it.
    """
    for node in list(graph.nodes):
        if node not in graph or node.op_type != "Pad":
            continue
        input_args = sorted_inputs(graph, node)
        if len(input_args) < 2:
            # No paddings input (e.g. given as an attribute); leave it alone.
            continue
        # Find the second input argument to Pad; only a Constant node can be
        # inspected statically here.
        nct = input_args[1]
        if not isinstance(nct, NodeT) or nct.op_type != "Constant":
            continue
        padding = node_attr_to_dict(nct)["value"]
        if any(p != 0 for p in padding):
            continue
        # Connect input of Pad to where output of Pad goes
        for _, to, index in list(graph.out_edges(node, "index")):
            graph.add_edge(input_args[0], to, index=index)
        # Remove Pad, and its Constant if the Pad was its only user.
        graph.remove_node(node)
        if graph.out_degree(nct) == 0:
            graph.remove_node(nct)
    return graph


def detect_flatten(graph: nx.DiGraph) -> nx.DiGraph:
    """Look for a shape-gather-unsqueeze-concat-reshape chain and replace that with flatten.

    The chain is only replaced when each link has a unique user and the
    Reshape reshapes the same tensor the Shape was taken of. The Flatten
    takes over the Shape's input and the Reshape's users.
    Modifies `graph` in place and returns it.
    """
    for node in list(graph.nodes):
        if node not in graph or node.op_type != "Shape":
            continue
        ng = get_next_in_chain(graph, "Gather", node)
        if ng is None:
            continue
        nu = get_next_in_chain(graph, "Unsqueeze", ng)
        nc = get_next_in_chain(graph, "Concat", nu)
        nr = get_next_in_chain(graph, "Reshape", nc)
        if nr is None:
            continue
        gather_args = sorted_inputs(graph, ng)
        reshape_args = sorted_inputs(graph, nr)
        # Only a flatten if Reshape acts on the tensor whose shape was taken.
        if len(gather_args) < 2 or not reshape_args:
            continue
        if reshape_args[0] not in sorted_inputs(graph, node):
            continue
        # The second input argument to Gather (will be a Constant node) is
        # taken away as well, unless it is still used elsewhere.
        nct = gather_args[1]
        extra = [nct] if graph.out_degree(nct) == 1 else []
        # Chain order matters: the first node provides the inputs, the last
        # one the outputs (see replace_chain_with_node_).
        chain = [node, ng, *extra, nu, nc, nr]
        suffix = node.name.split("_")[-1]
        gen_node = g.FlattenNode(f"Flatten_{suffix}")
        replace_chain_with_node_(graph, chain, gen_node)
    return graph


def get_next_in_chain(
    graph: nx.DiGraph, type_: str, node: Optional[NodeT]
) -> Optional[NodeT]:
    """
    Get a unique user node of the unique output of Node `node`,
    and return it if it has Type `type_`.

    Returns None (so calls can be chained) if `node` is None, has more than
    one output, has not exactly one user, or the user has another op type.
    """
    if node is None or len(node.output) != 1:
        return None  # Propagates None; Unique output
    users = list(graph.neighbors(node))
    if len(users) != 1 or users[0].op_type != type_:
        return None  # Unique user of the output; Correct type
    return users[0]


def replace_chain_with_node_(graph: nx.DiGraph, chain: list, node) -> nx.DiGraph:
    """Replace the nodes of `chain` by the single node `node`, in place.

    `node` takes over the inputs of `chain[0]` (renumbered 0..k-1 in argument
    order) and the out-edges of `chain[-1]`; all nodes in `chain` are removed.
    """
    inputs = sorted_inputs(graph, chain[0])
    succ = list(graph.out_edges(chain[-1], "index"))
    for i, n in enumerate(inputs):
        graph.add_edge(n, node, index=i)
    for _, to, index in succ:
        graph.add_edge(node, to, index=index)
    graph.remove_nodes_from(chain)
    return graph


def build_graph_with_mapping(
    graph: nx.DiGraph, node_mapping: Dict[NodeT, EmitNodeT]
) -> nx.DiGraph:
    """Return a copy of `graph` with every key of `node_mapping` replaced.

    A `g.DFGNode` value replaces its key one-to-one, keeping all edges. A
    `MarkedSubGraph` value is spliced in: its nodes/edges and entry edges are
    added, the key's out-edges are moved to the subgraph's exit, and the key
    is removed. The input `graph` is not modified.
    """
    graph = graph.copy()
    single_node, multi_node = {}, {}
    for replace_node, by_node in node_mapping.items():
        if isinstance(by_node, g.DFGNode):
            single_node[replace_node] = by_node
        else:
            multi_node[replace_node] = by_node
    # Entry edges name their sources by the original onnx node. If such a
    # source is itself replaced by a subgraph (e.g. a biased Conv feeding
    # another biased Conv), its value now comes out of that subgraph's exit.
    # Without this remap the already-removed onnx node would be re-added.
    exit_of = {n: sg.exit for n, sg in multi_node.items()}
    # We do one-to-many replacements first
    # because their predecessors are specified as onnx nodes.
    for replace_node, subgraph in multi_node.items():
        # Add subgraph itself (in place; nx.compose would copy the whole graph)
        graph.add_nodes_from(subgraph.subgraph.nodes(data=True))
        graph.add_edges_from(subgraph.subgraph.edges(data=True))
        # Add in edges
        graph.add_edges_from(
            (exit_of.get(f, f), t, data) for f, t, data in subgraph.entry_edges
        )
        # Add out edges
        for _, to, index in list(graph.out_edges(replace_node, "index")):
            graph.add_edge(subgraph.exit, to, index=index)
        # Remove old node
        graph.remove_node(replace_node)
    # Then do all one-to-one replacements.
    graph = nx.relabel_nodes(graph, single_node)
    return graph


def extract_tensors_from_graph(onnx_graph: GraphT) -> Dict[str, g.TensorNode]:
    """Map tensor names to tensor nodes for the graph's initializers and inputs.

    Initializers become `g.WeightTensor`s named weight_0, weight_1, ...; graph
    inputs that are not initializers become `g.InputTensor`s named input_0, ...
    """
    tensors = {}
    # parse weight
    weight_cnt = 0
    for weight_tensor in onnx_graph.initializer:
        tensors[weight_tensor.name] = g.WeightTensor(
            weight_tensor, f"weight_{weight_cnt}"
        )
        weight_cnt += 1
    # parse input (graph inputs may also list initializers; skip those)
    input_cnt = 0
    for input_ in onnx_graph.input:
        if input_.name in tensors:
            continue
        tensors[input_.name] = g.InputTensor(input_, f"input_{input_cnt}")
        input_cnt += 1
    return tensors


def sorted_inputs(graph: nx.DiGraph, node):
    """Return the predecessors of `node`, ordered by the edges' `index` attribute."""
    sorted_edges = sorted(graph.in_edges(node, "index"), key=lambda p: p[2])
    return [e[0] for e in sorted_edges]


def draw_graph(graph: nx.DiGraph, output_to):
    """Lay out `graph` with graphviz "dot" and draw it to `output_to`.

    Uses networkx's nx_agraph interface, which needs the optional pygraphviz
    package.
    """
    from networkx.drawing.nx_agraph import to_agraph

    agraph = to_agraph(graph)
    agraph.layout("dot")
    agraph.draw(output_to)