Skip to content
Snippets Groups Projects
Commit 27c49a85 authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Use networkx for onnx graph

parent 5f46b4b1
No related branches found
No related tags found
No related merge requests found
......@@ -13,6 +13,7 @@ from tensor import InputTensor, Tensor, WeightTensor
ModelT = onnx.ModelProto
GraphT = onnx.GraphProto
NodeT = onnx.NodeProto
NodeT.__hash__ = lambda self: id(self)
class GraphBuilder:
......@@ -43,7 +44,9 @@ class GraphBuilder:
# parse weight
weight_cnt = 0
for weight_tensor in onnx_graph.initializer:
tensors[weight_tensor.name] = WeightTensor(weight_tensor, f"weight_{weight_cnt}")
tensors[weight_tensor.name] = WeightTensor(
weight_tensor, f"weight_{weight_cnt}"
)
weight_cnt += 1
# parse input
input_cnt = 0
......@@ -67,12 +70,14 @@ class DFG(object):
if len(graph.output) > 1:
raise ValueError("Only single-output graph is supported")
self.output: str = graph.output[0].name
self._onnx_defs, self._onnx_uses = self.def_use(graph.node)
self._var_count = 0
self.tensors = tensors
self._graph = self._build_dfg(graph)
onnx_graph = self._build_onnx_dfg(graph)
self._graph = self._build_dfg(onnx_graph)
self._dce() # Remove unused values
################ Interfaces:
@property
def traverse_order(self) -> List[g.DFGNode]:
    """Nodes of the internal graph, listed in a topological order.

    The order is whatever `nx.topological_sort` yields for `self._graph`,
    materialised into a list.
    """
    ordering = nx.topological_sort(self._graph)
    return list(ordering)
......@@ -89,115 +94,74 @@ class DFG(object):
assert len(inputs) == 1
return inputs[0]
@staticmethod
def def_use(nodes: Iterable) -> Tuple[dict, dict]:
"""Computes def/use relation from a list of node.
################ Internal methods (high-level):
This method is duck-typed and operates on any node defining .input and .output.
"""
defs, uses = {}, defaultdict(list)
for n in nodes:
for input_ in n.input:
uses[input_].append(n)
for output in n.output:
defs[output] = n
return defs, uses
def _build_onnx_dfg(self, graph: GraphT) -> nx.DiGraph:
    """Build a DiGraph over the onnx nodes of `graph` from their def-use relation.

    GraphProto only encodes def-use implicitly (through value names shared
    between node outputs and node inputs); a DiGraph makes it explicit.

    Every entry of `graph.node` becomes a graph node. For each value name,
    an edge runs from the node that defines it to each node that uses it.
    Value names that no node in `graph.node` defines contribute no edge.
    """
    producers, consumers = self._def_use(graph.node)
    dfg = nx.DiGraph()
    dfg.add_nodes_from(graph.node)
    dfg.add_edges_from(
        (producers[value_name], user)
        for value_name, users in consumers.items()
        if value_name in producers  # Skip values defined outside graph.node
        for user in users
    )
    return dfg
def _build_dfg(self, onnx_graph: nx.DiGraph) -> nx.DiGraph:
    """Translate a graph of onnx nodes into a graph of our own DFG nodes.

    Flatten idioms are collapsed first via `_detect_flatten`. Each remaining
    onnx node is then passed to `_emit_node` and replaced, in a copy of the
    graph, by the chain of nodes it returns. Nodes that are already
    `g.DFGNode` instances are left as they are.

    Raises:
        ValueError: if `_emit_node` returned None for at least one node; the
            message lists the offending operators (at most the first 10).
    """
    onnx_graph = self._detect_flatten(onnx_graph)
    # Iterate over the source graph but mutate only the copy.
    dfg = onnx_graph.copy()
    unsupported = []
    for node in onnx_graph.nodes:
        if isinstance(node, g.DFGNode):
            continue
        emitted = self._emit_node(node)
        if emitted is not None:
            replace_node_with_chain_(dfg, node, emitted)
        else:
            unsupported.append(node)
    if not unsupported:
        return dfg
    # Collect every failure before raising so they are all reported at once.
    descriptions = [f"{n.name}({n.op_type})" for n in unsupported]
    if len(unsupported) > 10:  # Magic number
        raise ValueError(f"Unsupported operators (first 10): {descriptions[:10]}")
    raise ValueError(f"Unsupported operators: {descriptions}")
def _dce(self):
_, uses = self.def_use(self._graph.nodes)
_, uses = self._def_use(self._graph.nodes)
used_values = set(uses.keys())
unused_values = set(self.tensors.keys()) - used_values
for k in unused_values:
self.tensors.pop(k)
def _split_node_args(
    self, node1: g.DFGNode, node2: g.DFGNode, input_pos: int = 0, pop_pos: int = -1
) -> None:
    """Route the output of `node1` into `node2` through a freshly named value.

    Removes the input at `pop_pos` from `node1`, makes a new value name the
    sole output of `node1`, and stores that name at `input_pos` of `node2`'s
    inputs. The name is built from `self._var_count`, which is incremented
    afterwards so the next call gets a different name.
    """
    fresh_name = f"conv_{self._var_count}"
    node1.input.pop(pop_pos)
    node1.output, node2.input[input_pos] = [fresh_name], fresh_name
    self._var_count += 1
def _detect_flatten(self, graph: GraphT) -> Tuple[Dict[str, NodeT], List[g.DFGNode]]:
# Look for a shape-gather-unsqueeze-concat chain
nodes = graph.node
included_nodes = {} # Name to node
generated_nodes = []
def get_next_in_chain(type_: str, node: Optional[NodeT]) -> Optional[NodeT]:
"""
Get a unique user node of the unique output of Node `node`,
and return it if it has Type `type_`.
Also put this node into `included_nodes`.
"""
if node is None:
return None # propagates None
if len(node.output) != 1:
return None # Unique output
users = self._onnx_uses.get(node.output[0], [])
if len(users) != 1:
return None # Unique user of the output
(user,) = users
if user.op_type != type_:
return None # Correct type
if user.name in included_nodes:
# Part of this chain intersects another chain, so we give up
# TODO: in fact we should remove BOTH chain in this case
return None
return user
def _detect_flatten(self, graph: nx.DiGraph) -> nx.DiGraph:
"""Look for a shape-gather-unsqueeze-concat-reshape chain and replace that with flatten."""
def get_def_at_pos(node, pos: int):
return self._onnx_defs[node.input[pos]]
def add_nodes(ns):
for n in ns:
included_nodes[n.name] = n
from_, to = list(graph.in_edges(node))[pos]
return from_
for n in nodes:
if n.op_type != "Shape":
for node in list(graph.nodes):
if node.op_type != "Shape":
continue
ng = get_next_in_chain("Gather", n)
ng = self.get_next_in_chain(graph, "Gather", node)
# Find the second input argument to Gather (will be a Constant node)
# and take that away as well.
nct = get_def_at_pos(ng, 1)
nu = get_next_in_chain("Unsqueeze", ng)
nc = get_next_in_chain("Concat", nu)
nr = get_next_in_chain("Reshape", nc)
if nr is not None:
nodes = [n, ng, nct, nu, nc, nr]
add_nodes(nodes)
generated_nodes.append(g.FlattenNode.from_onnx_idiom(nodes))
return included_nodes, generated_nodes
def _build_dfg(self, graph: GraphT) -> nx.DiGraph:
error_nodes, generated_nodes = [], []
used_onnx_nodes, flatten_nodes = self._detect_flatten(graph)
generated_nodes.extend(flatten_nodes)
for onnx_node in graph.node:
if onnx_node.name in used_onnx_nodes:
continue
our_node = self._emit_node(onnx_node)
if our_node is None:
error_nodes.append(onnx_node)
else:
generated_nodes.extend(our_node)
if error_nodes:
error_repr = [f"{n.name}({n.op_type})" for n in error_nodes]
if len(error_nodes) > 10: # Magic number
raise ValueError(f"Unsupported operators (first 10): {error_repr[:10]}")
else:
raise ValueError(f"Unsupported operators: {error_repr}")
ret_graph = nx.DiGraph()
defs, uses = self.def_use(generated_nodes)
ret_graph.add_nodes_from(generated_nodes)
for onnx_value_name, use_nodes in uses.items():
if onnx_value_name not in defs:
nu = self.get_next_in_chain(graph, "Unsqueeze", ng)
nc = self.get_next_in_chain(graph, "Concat", nu)
nr = self.get_next_in_chain(graph, "Reshape", nc)
if nr is None:
continue
def_node = defs[onnx_value_name]
for use_node in use_nodes:
ret_graph.add_edge(def_node, use_node)
return ret_graph
nodes = [node, ng, nct, nu, nc, nr]
gen_node = g.FlattenNode.from_onnx_idiom(nodes)
graph = replace_graph_with_node_(graph, nodes, gen_node)
return graph
# This should be the place where partial evaluation happens
def _emit_node(self, onnx_node: NodeT) -> Optional[List[g.DFGNode]]:
......@@ -219,9 +183,9 @@ class DFG(object):
# Some tensors may need transposing
attrs = node_attr_to_dict(onnx_node)
# We cannot transpose input tensor (need a transpose op)
assert not attrs.get('transA', False)
assert not attrs.get("transA", False)
# But we can transpose weight tensor before emitting it
if attrs.get('transB', False):
if attrs.get("transB", False):
weight_tensor = self.tensors[onnx_node.input[1]]
assert isinstance(weight_tensor, WeightTensor)
weight_tensor.transpose_()
......@@ -248,3 +212,73 @@ class DFG(object):
if onnx_node.op_type in one_to_one_nodes:
return [one_to_one_nodes[onnx_node.op_type](onnx_node)]
return None
################ Internal methods (utils):
@staticmethod
def get_next_in_chain(
    graph: nx.DiGraph, type_: str, node: Optional[NodeT]
) -> Optional[NodeT]:
    """Step from `node` to its only neighbor in `graph`, if that has type `type_`.

    Returns the neighbor only when all of the following hold, else None:
      * `node` is not None (so a failed step propagates through chained calls),
      * `node` has exactly one output,
      * `graph.neighbors(node)` yields exactly one node,
      * that node's `op_type` equals `type_`.
    """
    if node is None:
        return None  # Propagates None
    if len(node.output) != 1:
        return None  # Unique output
    successors = list(graph.neighbors(node))
    if len(successors) != 1:
        return None  # Unique user of the output
    (only_user,) = successors
    return only_user if only_user.op_type == type_ else None
def _split_node_args(
    self, node1: g.DFGNode, node2: g.DFGNode, input_pos: int = 0, pop_pos: int = -1
) -> None:
    """Route the output of `node1` into `node2` through a freshly named value.

    Removes the input at `pop_pos` from `node1`, makes a new value name the
    sole output of `node1`, and stores that name at `input_pos` of `node2`'s
    inputs. The name is built from `self._var_count`, which is incremented
    afterwards so the next call gets a different name.
    """
    fresh_name = f"conv_{self._var_count}"
    node1.input.pop(pop_pos)
    node1.output, node2.input[input_pos] = [fresh_name], fresh_name
    self._var_count += 1
@staticmethod
def _def_use(nodes: Iterable) -> Tuple[dict, dict]:
    """Compute the def/use relation of a sequence of nodes.

    Duck-typed: works on any node object exposing `.input` and `.output`
    (iterables of value names).

    Returns:
        A pair `(defs, uses)`. `defs` maps a value name to the node that
        outputs it (the last such node wins if several do). `uses` is a
        `defaultdict(list)` mapping a value name to the nodes taking it as
        input, in iteration order.
    """
    producers: dict = {}
    consumers = defaultdict(list)
    for node in nodes:
        for value_name in node.input:
            consumers[value_name].append(node)
        producers.update((value_name, node) for value_name in node.output)
    return producers, consumers
def replace_graph_with_node_(graph: nx.DiGraph, subgraph: Iterable, node) -> nx.DiGraph:
    """Collapse the nodes in `subgraph` into the single node `node`, in place.

    Edges that entered the subgraph from a node outside it are redirected to
    `node`; edges that left the subgraph now leave from `node`. Edges between
    two subgraph members disappear together with the members.

    Args:
        graph: Graph to modify in place.
        subgraph: Nodes of `graph` to remove; consumed exactly once.
        node: Replacement node. It is always present in the result, even when
            the subgraph had no neighbors outside itself.

    Returns:
        The same `graph` object, for call chaining.
    """
    members = list(subgraph)
    left_neighbors, right_neighbors = set(), set()
    for n in members:
        left_neighbors.update(from_ for from_, _ in graph.in_edges(n))
        right_neighbors.update(to for _, to in graph.out_edges(n))
        graph.remove_node(n)
    # Add the replacement explicitly: relying on add_edge alone would lose it
    # whenever no outside neighbor survives.
    graph.add_node(node)
    # Neighbors that were themselves subgraph members are gone by now, so the
    # membership test keeps only the outside neighbors.
    for n in left_neighbors:
        if n in graph:
            graph.add_edge(n, node)
    for n in right_neighbors:
        if n in graph:
            graph.add_edge(node, n)
    return graph
def replace_node_with_chain_(graph: nx.DiGraph, node, chain: Iterable) -> nx.DiGraph:
    """Replace `node` with a linear chain of nodes, in place.

    The chain nodes are linked in order. Incoming edges of `node` are
    redirected to the first chain node and outgoing edges of `node` now
    leave from the last chain node. `node` itself is then removed.

    Args:
        graph: Graph to modify in place.
        node: Node of `graph` to replace.
        chain: Replacement nodes, in order. If empty, `node` is simply removed
            and its former neighbors are not reconnected.

    Returns:
        The same `graph` object, for call chaining.
    """
    chain = list(chain)
    if not chain:
        graph.remove_node(node)
        return graph
    # Add the chain nodes explicitly: a single-node chain replacing a node
    # without edges would otherwise never enter the graph.
    graph.add_nodes_from(chain)
    for n1, n2 in zip(chain, chain[1:]):
        graph.add_edge(n1, n2)  # Add the chain first
    head, tail = chain[0], chain[-1]
    # Snapshot the edge views before adding edges; they are live views of
    # the graph being modified.
    for from_, _ in list(graph.in_edges(node)):
        graph.add_edge(from_, head)
    for _, to in list(graph.out_edges(node)):
        graph.add_edge(tail, to)
    graph.remove_node(node)
    return graph
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment