Skip to content
Snippets Groups Projects
Commit a93c7d2a authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Use an onnx simplifier to reduce work

parent b35d34c3
No related branches found
No related tags found
No related merge requests found
...@@ -41,15 +41,19 @@ class ModelExporter: ...@@ -41,15 +41,19 @@ class ModelExporter:
hpvmc: bool = True, hpvmc: bool = True,
opset: Optional[int] = None, opset: Optional[int] = None,
): ):
from onnxsim import simplify
self.tune_dataset, self.test_dataset = tune_dataset, test_dataset self.tune_dataset, self.test_dataset = tune_dataset, test_dataset
self.dataset_shape = self._check_datasets(tune_dataset, test_dataset) self.dataset_shape = self._check_datasets(tune_dataset, test_dataset)
self.dataset_size = self.dataset_shape[0] self.dataset_size = self.dataset_shape[0]
onnx_model = self._load_model(model, self.dataset_shape) onnx_model = self._load_model(model, self.dataset_shape)
if opset is not None: if opset is not None:
onnx_model = check_onnx_version(onnx_model, opset) onnx_model = check_onnx_version(onnx_model, opset)
self.onnx_model = onnx.shape_inference.infer_shapes(onnx_model) onnx_model, check = simplify(onnx_model)
assert check, "Simplified ONNX model could not be validated"
onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
self.dfg = DFG(self.onnx_model.graph) self.dfg = DFG(onnx_model.graph)
self.output_dir = Path(output_dir) self.output_dir = Path(output_dir)
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
self.weight_dir = self.output_dir / self.weight_dir_name self.weight_dir = self.output_dir / self.weight_dir_name
......
...@@ -128,10 +128,9 @@ class DFG(object): ...@@ -128,10 +128,9 @@ class DFG(object):
processed together, then each unprocessed node is generated into processed together, then each unprocessed node is generated into
1 or more nodes.""" 1 or more nodes."""
# Remove subgraphs that can be a single Flatten instead # Gemm in tensor_runtime does reshape automatically
onnx_graph = detect_flatten(onnx_graph) # it also doesn't have a dedicated reshape operator
# Remove subgraphs that look like padding but does nothing onnx_graph = drop_reshape_before_gemm(onnx_graph)
onnx_graph = remove_no_padding(onnx_graph)
# For each onnx node, generate our nodes # For each onnx node, generate our nodes
node_to_nodes, error_nodes = {}, [] node_to_nodes, error_nodes = {}, []
for onnx_node in nx.topological_sort(onnx_graph): for onnx_node in nx.topological_sort(onnx_graph):
...@@ -228,45 +227,22 @@ def def_use(nodes: Iterable) -> Tuple[dict, dict]: ...@@ -228,45 +227,22 @@ def def_use(nodes: Iterable) -> Tuple[dict, dict]:
return defs, uses return defs, uses
def drop_reshape_before_gemm(graph: nx.DiGraph) -> nx.DiGraph:
    """Remove constant-shape ``Reshape`` nodes that directly feed a ``Gemm``.

    Gemm in tensor_runtime performs the (n-1)-d flatten/reshape of its input
    automatically, and the runtime has no dedicated reshape operator, so a
    ``Reshape(x, [1, -1]) -> Gemm`` pair is redundant: the Reshape is dropped
    and its data input is wired straight into the Gemm.

    The graph is modified in place and also returned for call chaining.
    """
    for node in list(graph.nodes):
        if node.op_type != "Reshape":
            continue
        reshape_input, target_shape = sorted_inputs(graph, node)
        if not isinstance(target_shape, g.WeightTensor):
            # Target shape is computed at runtime, not a constant -- the
            # Reshape cannot be statically removed; leave it alone.
            continue
        n_gemm = get_next_in_chain(graph, "Gemm", node)
        if n_gemm is None:
            # Reshape doesn't (solely) feed a Gemm; keep it.
            continue
        # Gemm only absorbs an (n-1)-d flatten, i.e. a reshape to [1, -1].
        # NOTE: `assert` is stripped under `python -O`; kept for parity with
        # the surrounding code style.
        assert list(target_shape.input_data) == [1, -1]
        # Connect input of reshape to gemm (as its 0th argument), then
        # remove the reshape node itself.
        graph.add_edge(reshape_input, n_gemm, index=0)
        graph.remove_node(node)
    return graph
...@@ -285,17 +261,6 @@ def get_next_in_chain( ...@@ -285,17 +261,6 @@ def get_next_in_chain(
return users[0] return users[0]
def replace_chain_with_node_(graph: nx.DiGraph, chain: list, node) -> nx.DiGraph:
    """Splice ``node`` into the graph in place of the linear chain ``chain``.

    ``node`` inherits the (index-ordered) inputs of the chain's first node and
    the outgoing edges of the chain's last node; every node of the chain is
    then deleted. The graph is mutated in place and returned.
    """
    head, tail = chain[0], chain[-1]
    # Materialize the edge view before mutating the graph.
    outgoing = list(graph.out_edges(tail, "index"))
    for position, source in enumerate(sorted_inputs(graph, head)):
        graph.add_edge(source, node, index=position)
    for _, target, position in outgoing:
        graph.add_edge(node, target, index=position)
    graph.remove_nodes_from(chain)
    return graph
def build_graph_with_mapping( def build_graph_with_mapping(
graph: nx.DiGraph, node_mapping: Dict[NodeT, EmitNodeT] graph: nx.DiGraph, node_mapping: Dict[NodeT, EmitNodeT]
) -> nx.DiGraph: ) -> nx.DiGraph:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment