From a93c7d2a6201f1c32ff5b2f290622ab5cf8fc1b3 Mon Sep 17 00:00:00 2001 From: Yifan Zhao <yifanz16@illinois.edu> Date: Sun, 31 Jan 2021 06:47:11 -0600 Subject: [PATCH] Use an onnx simplifier to reduce work --- .../projects/torch2hpvm/torch2hpvm/compile.py | 8 ++- .../torch2hpvm/torch2hpvm/graph_builder.py | 65 +++++-------------- 2 files changed, 21 insertions(+), 52 deletions(-) diff --git a/hpvm/projects/torch2hpvm/torch2hpvm/compile.py b/hpvm/projects/torch2hpvm/torch2hpvm/compile.py index 85bf5c77fc..9cf1f1a167 100644 --- a/hpvm/projects/torch2hpvm/torch2hpvm/compile.py +++ b/hpvm/projects/torch2hpvm/torch2hpvm/compile.py @@ -41,15 +41,19 @@ class ModelExporter: hpvmc: bool = True, opset: Optional[int] = None, ): + from onnxsim import simplify + self.tune_dataset, self.test_dataset = tune_dataset, test_dataset self.dataset_shape = self._check_datasets(tune_dataset, test_dataset) self.dataset_size = self.dataset_shape[0] onnx_model = self._load_model(model, self.dataset_shape) if opset is not None: onnx_model = check_onnx_version(onnx_model, opset) - self.onnx_model = onnx.shape_inference.infer_shapes(onnx_model) + onnx_model, check = simplify(onnx_model) + assert check, "Simplified ONNX model could not be validated" + onnx_model = onnx.shape_inference.infer_shapes(onnx_model) - self.dfg = DFG(self.onnx_model.graph) + self.dfg = DFG(onnx_model.graph) self.output_dir = Path(output_dir) os.makedirs(output_dir, exist_ok=True) self.weight_dir = self.output_dir / self.weight_dir_name diff --git a/hpvm/projects/torch2hpvm/torch2hpvm/graph_builder.py b/hpvm/projects/torch2hpvm/torch2hpvm/graph_builder.py index a07c18e72d..f40e604d93 100644 --- a/hpvm/projects/torch2hpvm/torch2hpvm/graph_builder.py +++ b/hpvm/projects/torch2hpvm/torch2hpvm/graph_builder.py @@ -128,10 +128,9 @@ class DFG(object): processed together, then each unprocessed node is generated into 1 or more nodes.""" - # Remove subgraphs that can be a single Flatten instead - onnx_graph = detect_flatten(onnx_graph) - # Remove subgraphs that look like padding but does nothing - onnx_graph = remove_no_padding(onnx_graph) + # Gemm in tensor_runtime does reshape automatically + # it also doesn't have a dedicated reshape operator + onnx_graph = drop_reshape_before_gemm(onnx_graph) # For each onnx node, generate our nodes node_to_nodes, error_nodes = {}, [] for onnx_node in nx.topological_sort(onnx_graph): @@ -228,45 +227,22 @@ def def_use(nodes: Iterable) -> Tuple[dict, dict]: return defs, uses -def remove_no_padding(graph: nx.DiGraph) -> nx.DiGraph: - """Remove subgraphs that look like padding but does nothing.""" - for node in list(graph.nodes): - if node.op_type != "Pad": - continue - input_args = sorted_inputs(graph, node) - # Find the second input argument to Pad (will be a Constant node) - # and take that away as well. - nct = input_args[1] - padding = node_attr_to_dict(nct)["value"] - if any(p != 0 for p in padding): - continue - # Connect input of Pad to where output of Pad goes - succ = graph.out_edges(node, "index") - for _, to, index in succ: - graph.add_edge(input_args[0], to, index=index) - # Remove nodes - graph.remove_nodes_from([node, nct]) - return graph - - -def detect_flatten(graph: nx.DiGraph) -> nx.DiGraph: +def drop_reshape_before_gemm(graph: nx.DiGraph) -> nx.DiGraph: """Look for a shape-gather-unsqueeze-concat-reshape chain and replace that with flatten.""" - for node in list(graph.nodes): - if node.op_type != "Shape": + if node.op_type != "Reshape": + continue + reshape_input, target_shape = sorted_inputs(graph, node) + if not isinstance(target_shape, g.WeightTensor): # Not constant shape, nope continue - ng = get_next_in_chain(graph, "Gather", node) - # Find the second input argument to Gather (will be a Constant node) - # and take that away as well. - nct = sorted_inputs(graph, ng)[1] - nu = get_next_in_chain(graph, "Unsqueeze", ng) - nc = get_next_in_chain(graph, "Concat", nu) - nr = get_next_in_chain(graph, "Reshape", nc) - if nr is None: + n_gemm = get_next_in_chain(graph, "Gemm", node) + if n_gemm is None: continue - _, suffix = node.name.split("_") - gen_node = g.FlattenNode(f"Flatten_{suffix}") - replace_chain_with_node_(graph, [node, ng, nct, nu, nc, nr], gen_node) + # Must be an (n-1)-d flatten before gemm + assert list(target_shape.input_data) == [1, -1] + # Connect input of reshape to gemm, then remove reshape + graph.add_edge(reshape_input, n_gemm, index=0) + graph.remove_node(node) return graph @@ -285,17 +261,6 @@ def get_next_in_chain( return users[0] -def replace_chain_with_node_(graph: nx.DiGraph, chain: list, node) -> nx.DiGraph: - inputs = sorted_inputs(graph, chain[0]) - succ = graph.out_edges(chain[-1], "index") - for i, n in enumerate(inputs): - graph.add_edge(n, node, index=i) - for _, to, index in succ: - graph.add_edge(node, to, index=index) - graph.remove_nodes_from(chain) - return graph - - def build_graph_with_mapping( graph: nx.DiGraph, node_mapping: Dict[NodeT, EmitNodeT] ) -> nx.DiGraph: -- GitLab