From 30500ca36dc4e1cd874f8d8d4bd16a4be921f934 Mon Sep 17 00:00:00 2001
From: Yifan Zhao <yifanz16@illinois.edu>
Date: Sun, 6 Dec 2020 02:23:53 -0600
Subject: [PATCH] Rewrote graph manipulation LLVM-style
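
Tensors and operators now live in one networkx DiGraph of DFGNode objects;
use-def edges carry an "index" attribute recording the argument position, so
nodes no longer track their inputs by ONNX value names. GraphBuilder is folded
into DFG, tensor.py is folded into graph_ir.py (Tensor -> TensorNode), and the
two code generators share a common CodeGen base class.

A minimal sketch of the new entry point (mirroring main.py; the file name and
input size below are illustrative):

    import onnx
    from graph_builder import DFG
    from codegen_hpvm import HpvmCodeGen

    model = onnx.load("model.onnx")
    dfg = DFG(model.graph)  # replaces GraphBuilder(model).dfg
    # args: output dir, input size, batch size (None = use input size), prefix
    HpvmCodeGen(dfg, "./out", 500, None, None).compile()
    dfg.dump_weights("./out")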

---
 hpvm/projects/onnx/frontend/codegen_hpvm.py   |  87 ++--
 hpvm/projects/onnx/frontend/codegen_tensor.py |  53 +--
 hpvm/projects/onnx/frontend/graph_builder.py  | 418 +++++++++---------
 hpvm/projects/onnx/frontend/graph_ir.py       | 215 +++++----
 hpvm/projects/onnx/frontend/main.py           |  12 +-
 hpvm/projects/onnx/frontend/tensor.py         |  54 ---
 6 files changed, 409 insertions(+), 430 deletions(-)
 delete mode 100644 hpvm/projects/onnx/frontend/tensor.py

diff --git a/hpvm/projects/onnx/frontend/codegen_hpvm.py b/hpvm/projects/onnx/frontend/codegen_hpvm.py
index fc9788e243..cd9eac2330 100644
--- a/hpvm/projects/onnx/frontend/codegen_hpvm.py
+++ b/hpvm/projects/onnx/frontend/codegen_hpvm.py
@@ -1,11 +1,11 @@
 from os import PathLike
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Tuple, Union
 
 import jinja2
 
 from graph_builder import DFG
-from tensor import Tensor, WeightTensor
+from graph_ir import DFGNode, TensorNode, WeightTensor
 
 TEMPLATE_FILE = "template_hpvm.cpp"
 loader = jinja2.FileSystemLoader(searchpath=Path(__file__).parent)
@@ -13,7 +13,7 @@ template_env = jinja2.Environment(loader=loader, trim_blocks=True)
 template = template_env.get_template(TEMPLATE_FILE)
 
 
-class HpvmCodeGen:
+class CodeGen:
     def __init__(
         self,
         dfg: DFG,
@@ -23,23 +23,27 @@ class HpvmCodeGen:
         prefix: str = None,
     ):
         self.dfg = dfg
-        self.tensors = dfg.tensors
         self.var_count = 0
         self.output_dir = Path(output_dir)
         self.prefix = prefix
         # Some reasoning about the input information
-        input_arg, input_tensor = self.dfg.discover_input_var()
-        self.input_name = input_arg
+        assert len(self.dfg.inputs) == 1
+        input_tensor = self.dfg.inputs[0]
+        self.input_name = input_tensor.name
         self.input_shape = input_tensor.shape[1:]
         self.input_size = input_size
         self.batch_size = batch_size or input_size
-        # self.variables is a "onnx name to our name" map
+        # self.variables is a "node to our name" map
         # Each value is (varname, bool) and the bool indicates
         # "is root node input" or not.
         IdenT = Union[str, int]
-        root_args = sorted([t.name for t in self.tensors.values()])
-        self.variables: Dict[str, Tuple[IdenT, bool]] = {
-            f_name: (index, True) for index, f_name in enumerate(root_args)
+        self.root_args = sorted(
+            [n for n in dfg.traverse_order if isinstance(n, TensorNode)],
+            key=lambda n: n.name,
+        )
+        self.weights = [n for n in self.root_args if isinstance(n, WeightTensor)]
+        self.variables: Dict[DFGNode, Tuple[IdenT, bool]] = {
+            node: (index, True) for index, node in enumerate(self.root_args)
         }
 
     ################################################
@@ -51,15 +55,29 @@ class HpvmCodeGen:
         self.var_count += 1
         return varname
 
-    ################################################
-    # CodeGen functions
-    ################################################
+    @classmethod
+    def emit_weights(cls, weights: List[WeightTensor]) -> List[dict]:
+        ret = []
+        for weight in weights:
+            name = cls.make_c_identifier(weight.name)
+            file_path = f"{weight.new_name}_path.bin"
+            ret.append({"name": name, "shape": weight.shape, "filename": file_path})
+        return ret
+
+    @staticmethod
+    def make_c_identifier(name: str) -> str:
+        name = name.replace(".", "_")
+        if name[0].isnumeric():
+            name = "_" + name
+        return name
+
 
-    def _emit_hpvm_node_edges(self, input_vars: List[str]) -> List[dict]:
+class HpvmCodeGen(CodeGen):
+    def _emit_hpvm_node_edges(self, input_vars: List[DFGNode]) -> List[dict]:
         ret = []
         it = 0
-        for onnx_var_name in input_vars:
-            hpvm_var_name, is_root_input = self.variables[onnx_var_name]
+        for node in input_vars:
+            hpvm_var_name, is_root_input = self.variables[node]
             if is_root_input:
                 assert isinstance(hpvm_var_name, int)
                 ret.append(
@@ -75,20 +93,23 @@ class HpvmCodeGen:
     def emit_hpvm_node_structures(self) -> List[dict]:
         node_envs = []
         for node in self.dfg.traverse_order:
+            if isinstance(node, TensorNode):
+                continue
+            inputs = self.dfg.node_args(node)
             func_name, extra_args = node.hpvm_codegen()
             if func_name == "":  # No code generation
             # Node must have a single input; we equate the output to
             # the input and skip code generation.
-                assert len(node.input) == 1 and len(node.output) == 1
-                self.variables[node.output[0]] = self.variables[node.input[0]]
+                assert len(inputs) == 1
+                self.variables[node] = self.variables[inputs[0]]
                 continue
             varname = self._allocate_varname()
-            self.variables[node.output[0]] = varname, False  # not root-node arg
+            self.variables[node] = varname, False  # not root-node arg
             node_envs.append(
                 {
                     "name": varname,
-                    "input_size": len(node.input),
-                    "edges": self._emit_hpvm_node_edges(node.input),
+                    "input_size": len(inputs),
+                    "edges": self._emit_hpvm_node_edges(inputs),
                     "call_name": func_name,
                     "call_args": extra_args,
                 }
@@ -97,8 +118,8 @@ class HpvmCodeGen:
 
     def emit_root_io(self) -> Tuple[List[str], str]:
         input_args = [
-            make_c_identifier(name)
-            for name, (_, is_root) in self.variables.items()
+            self.make_c_identifier(node.name)
+            for node, (_, is_root) in self.variables.items()
             if is_root
         ]
         output_arg = self.variables[self.dfg.output][0]
@@ -107,7 +128,7 @@ class HpvmCodeGen:
     def compile(self) -> None:
         nodes = self.emit_hpvm_node_structures()
         inputs, output = self.emit_root_io()
-        weights = emit_weights(self.tensors)
+        weights = self.emit_weights(self.weights)
         prefix = self.prefix or self.output_dir
         with open(self.output_dir / "hpvm_src.cc", "w") as f:
             f.write(
@@ -123,21 +144,3 @@ class HpvmCodeGen:
                     prefix=prefix,
                 )
             )
-
-
-def make_c_identifier(name: str) -> str:
-    name = name.replace(".", "_")
-    if name[0].isnumeric():
-        name = "_" + name
-    return name
-
-
-def emit_weights(tensors: Dict[str, Tensor]) -> List[dict]:
-    ret = []
-    for name, tensor in tensors.items():
-        if not isinstance(tensor, WeightTensor):
-            continue
-        name = make_c_identifier(name)
-        file_path = f"{tensor.new_name}_path.bin"
-        ret.append({"name": name, "shape": tensor.shape, "filename": file_path})
-    return ret
diff --git a/hpvm/projects/onnx/frontend/codegen_tensor.py b/hpvm/projects/onnx/frontend/codegen_tensor.py
index 68ea009266..f0a7310605 100644
--- a/hpvm/projects/onnx/frontend/codegen_tensor.py
+++ b/hpvm/projects/onnx/frontend/codegen_tensor.py
@@ -1,11 +1,10 @@
-from os import PathLike
 from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Dict, List
 
 import jinja2
 
-from codegen_hpvm import emit_weights, make_c_identifier
-from graph_builder import DFG
+from codegen_hpvm import CodeGen
+from graph_ir import DFGNode, TensorNode
 
 TEMPLATE_FILE = "template_tensor.cpp"
 loader = jinja2.FileSystemLoader(searchpath=Path(__file__).parent)
@@ -13,32 +12,13 @@ template_env = jinja2.Environment(loader=loader, trim_blocks=True)
 template = template_env.get_template(TEMPLATE_FILE)
 
 
-class TensorCodeGen:
-    def __init__(
-        self, dfg: DFG, output_dir: PathLike, input_shape: Optional[List[int]] = None
-    ):
-        self.tensors = dfg.tensors
-        self.dfg = dfg
-        self.var_count = 0
-        self.output_dir = Path(output_dir)
-        input_arg, input_tensor = self.dfg.discover_input_var()
-        self.input_info = input_arg, (input_shape or input_tensor.shape)
-        # self.variables is a "onnx name to our name" map
-        # Each value is (varname, bool) and the bool indicates
-        # "is root node input" or not.
-        IdenT = Union[str, int]
-        self.variables: Dict[str, IdenT] = {
-            k: make_c_identifier(k) for k in self.tensors
+class TensorCodeGen(CodeGen):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.variables: Dict[DFGNode, str] = {
+            n: self.make_c_identifier(n.name) for n in self.root_args
         }
 
-    ################################################
-    # Aux functions
-    ################################################
-    def _allocate_varname(self) -> str:
-        varname = f"var_{self.var_count}"
-        self.var_count += 1
-        return varname
-
     ################################################
     # CodeGen functions
     ################################################
@@ -46,16 +26,19 @@ class TensorCodeGen:
     def emit_graph(self) -> List[dict]:
         graph_code = []
         for node in self.dfg.traverse_order:
+            if isinstance(node, TensorNode):
+                continue
+            inputs = self.dfg.node_args(node)
             func_name, extra_args = node.codegen()
             if func_name == "":  # No code generation
             # Node must have a single input; we equate the output to
             # the input and skip code generation.
-                assert len(node.input) == 1 and len(node.output) == 1
-                self.variables[node.output[0]] = self.variables[node.input[0]]
+                assert len(inputs) == 1
+                self.variables[node] = self.variables[inputs[0]]
                 continue
-            input_args = [self.variables[arg] for arg in node.input] + extra_args
             varname = self._allocate_varname()
-            self.variables[node.output[0]] = varname
+            self.variables[node] = varname
+            input_args = [self.variables[n] for n in inputs] + extra_args
             graph_code.append(
                 {"output": varname, "inputs": input_args, "function": func_name}
             )
@@ -72,11 +55,11 @@ class TensorCodeGen:
         with open(self.output_dir / "src.cc", "w") as f:
             f.write(
                 template.render(
-                    input=self.input_info[0],
-                    input_shape=self.input_info[1],
+                    input=self.input_name,
+                    input_shape=self.input_shape,
                     output=output_arg,
                     graph_code=graph_code,
-                    weights=emit_weights(self.tensors),
+                    weights=self.emit_weights(self.weights),
                     output_dir=self.output_dir,
                 )
             )
diff --git a/hpvm/projects/onnx/frontend/graph_builder.py b/hpvm/projects/onnx/frontend/graph_builder.py
index 87d97f5efd..ccb66bc7e4 100644
--- a/hpvm/projects/onnx/frontend/graph_builder.py
+++ b/hpvm/projects/onnx/frontend/graph_builder.py
@@ -1,202 +1,157 @@
-from collections import defaultdict
+from collections import defaultdict
 from onnx_attr import node_attr_to_dict
 from os import PathLike
 from pathlib import Path
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple, Union
 
 import networkx as nx
 import onnx
 
 import graph_ir as g
-from tensor import InputTensor, Tensor, WeightTensor
 
-ModelT = onnx.ModelProto
 GraphT = onnx.GraphProto
 NodeT = onnx.NodeProto
 NodeT.__hash__ = lambda self: id(self)
 
 
-class GraphBuilder:
-    def __init__(self, model: ModelT):
-        self._check_model(model)
-        self.tensors = self._extract_tensors_from_graph(model.graph)
-        self.dfg = DFG(model.graph, self.tensors)
+class MarkedSubGraph:
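+    """A replacement subgraph for one ONNX node, with marked entries and exit.
+
+    `entry_edges` are (from_node, to_node, arg_index) triples; `exit` is the
+    node that stands in for the replaced node's output.
+    """
+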
+    def __init__(self, subgraph: nx.DiGraph, entry_edges, exit) -> None:
+        assert all(to in subgraph for _, to, _ in entry_edges)
+        assert exit in subgraph
+        self.subgraph, self.exit = subgraph, exit
+        self.entry_edges = [(f, t, {"index": i}) for f, t, i in entry_edges]
 
-    ################################################
-    # Aux functions for graph building
-    ################################################
+    @classmethod
+    def idiomatic_1to2(cls, node1, node2, predecessors):
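+        """Builds the 2-node subgraph used for the conv/gemm "op + bias" split:
+        p0 and p1 feed node1 (args 0 and 1), and p2 feeds node2 as its arg 1."""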
+        p0, p1, p2 = predecessors
+        graph = nx.DiGraph()
+        graph.add_edge(node1, node2, index=0)
+        return cls(graph, [(p0, node1, 0), (p1, node1, 1), (p2, node2, 1)], node2)
 
-    @staticmethod
-    def _check_model(onnx_model: ModelT):
-        import warnings
-        from onnx import checker, onnx_cpp2py_export
 
-        if hasattr(checker, "check_model"):
-            # try use onnx's own model checker before converting any model
-            try:
-                checker.check_model(onnx_model)
-            except onnx_cpp2py_export.checker.ValidationError as e:
-                warnings.warn(str(e))
-
-    @staticmethod
-    def _extract_tensors_from_graph(onnx_graph: GraphT) -> Dict[str, Tensor]:
-        tensors = {}
-        # parse weight
-        weight_cnt = 0
-        for weight_tensor in onnx_graph.initializer:
-            tensors[weight_tensor.name] = WeightTensor(
-                weight_tensor, f"weight_{weight_cnt}"
-            )
-            weight_cnt += 1
-        # parse input
-        input_cnt = 0
-        for input_ in onnx_graph.input:
-            if input_.name in tensors:
-                continue
-            tensors[input_.name] = InputTensor(input_, f"input_{input_cnt}")
-            input_cnt += 1
-        return tensors
-
-    def dump_weights(self, output_dir: PathLike) -> None:
-        output_dir = Path(output_dir)
-        for tensor in self.tensors.values():
-            if not isinstance(tensor, WeightTensor):
-                continue
-            tensor.dump_weight(output_dir / (tensor.new_name + "_path.bin"))
+EmitNodeT = Union[MarkedSubGraph, g.DFGNode]
 
 
 class DFG(object):
-    def __init__(self, graph: GraphT, tensors: Dict[str, Tensor]):
-        if len(graph.output) > 1:
-            raise ValueError("Only single-output graph is supported")
-        self.output: str = graph.output[0].name
+    def __init__(self, graph: GraphT):
+        self._check_model(graph)
         self._var_count = 0
-        self.tensors = tensors
         onnx_graph = self._build_onnx_dfg(graph)
-        self._graph = self._build_dfg(onnx_graph)
-        self._dce()  # Remove unused values
+        self.graph = self._build_dfg(onnx_graph)
+        self.inputs, self.output = self._dce_get_io_info()
 
     ################ Interfaces:
 
     @property
     def traverse_order(self) -> List[g.DFGNode]:
-        return list(nx.topological_sort(self._graph))
+        return list(nx.topological_sort(self.graph))
 
-    def discover_input_var(self) -> Tuple[str, InputTensor]:
-        """Guess which input tensor is the "input" to the ONNX model.
-        This is useful when we batch through the input tensor."""
+    def node_args(self, node):
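+        """Returns the argument nodes of `node`, sorted by argument position."""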
+        sorted_edges = sorted(self.graph.in_edges(node, "index"), key=lambda p: p[2])
+        return [e[0] for e in sorted_edges]
 
-        inputs = [
-            (name, tensor)
-            for name, tensor in self.tensors.items()
-            if isinstance(tensor, InputTensor)
-        ]
-        assert len(inputs) == 1
-        return inputs[0]
+    def dump_weights(self, output_dir: PathLike) -> None:
+        output_dir = Path(output_dir)
+        for node in self.graph.nodes:
+            if not isinstance(node, g.WeightTensor):
+                continue
+            node.dump_weight(output_dir / (node.new_name + "_path.bin"))
 
     ################ Internal methods (high-level):
 
+    @staticmethod
+    def _check_model(onnx_graph: GraphT):
+        import warnings
+        from onnx import checker, onnx_cpp2py_export
+
+        # try to use onnx's own model checker before converting any model
+        try:
+            checker.check_graph(onnx_graph)
+        except onnx_cpp2py_export.checker.ValidationError as e:
+            warnings.warn(str(e))
+        if any(len(n.output) > 1 for n in onnx_graph.node):
+            raise ValueError("All node must have single output")
+        if len(onnx_graph.output) > 1:
+            raise ValueError("Graph must have single output")
+
     def _build_onnx_dfg(self, graph: GraphT) -> nx.DiGraph:
         """Creates a DiGraph (by use-def relation) of onnx nodes from onnx GraphProto.
         DiGraph is easier to use as a graph compared to GraphProto where use-def is implicit."""
 
         ret_graph = nx.DiGraph()
-        onnx_defs, onnx_uses = self._def_use(graph.node)
+        onnx_defs, onnx_uses = def_use(graph.node)
+        tensors = extract_tensors_from_graph(graph)
         ret_graph.add_nodes_from(graph.node)
         for onnx_value_name, use_nodes in onnx_uses.items():
-            if onnx_value_name not in onnx_defs:
-                continue
-            def_node = onnx_defs[onnx_value_name]
-            for use_node in use_nodes:
-                ret_graph.add_edge(def_node, use_node)
+            def_node = onnx_defs.get(onnx_value_name)
+            if def_node is None:
+                def_node = tensors[onnx_value_name]
+            for use_node, used_at_narg in use_nodes:
+                ret_graph.add_edge(def_node, use_node, index=used_at_narg)
         return ret_graph
 
     def _build_dfg(self, onnx_graph: nx.DiGraph) -> nx.DiGraph:
-        onnx_graph = self._detect_flatten(onnx_graph)
+        onnx_graph = detect_flatten(onnx_graph)
         # For each onnx node, generate our nodes
-        ret_graph = onnx_graph.copy()
-        error_nodes = []
-        for onnx_node in onnx_graph.nodes:
-            if isinstance(onnx_node, g.DFGNode):
-                continue
-            our_nodes = self._emit_node(onnx_node)
+        node_to_nodes, error_nodes = {}, []
+        for onnx_node in nx.topological_sort(onnx_graph):
+            our_nodes = self._emit_node(onnx_graph, onnx_node)
             if our_nodes is None:
                 error_nodes.append(onnx_node)
             else:
-                replace_node_with_chain_(ret_graph, onnx_node, our_nodes)
+                node_to_nodes[onnx_node] = our_nodes
         if error_nodes:
             error_repr = [f"{n.name}({n.op_type})" for n in error_nodes]
             if len(error_nodes) > 10:  # Magic number
                 raise ValueError(f"Unsupported operators (first 10): {error_repr[:10]}")
             else:
                 raise ValueError(f"Unsupported operators: {error_repr}")
-        return ret_graph
-
-    def _dce(self):
-        _, uses = self._def_use(self._graph.nodes)
-        used_values = set(uses.keys())
-        unused_values = set(self.tensors.keys()) - used_values
-        for k in unused_values:
-            self.tensors.pop(k)
-
-    def _detect_flatten(self, graph: nx.DiGraph) -> nx.DiGraph:
-        """Look for a shape-gather-unsqueeze-concat-reshape chain and replace that with flatten."""
+        return build_graph_with_mapping(onnx_graph, node_to_nodes)
+
+    def _dce_get_io_info(self):
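+        """Removes nodes not connected to any input (dead code), then returns
+        (inputs, the single output node)."""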
+        inputs = [n for n in self.graph if isinstance(n, g.InputTensor)]
+        inputs_set = set(inputs)
+        reachables = set()
+        for component in nx.connected_components(self.graph.to_undirected()):
+            # If any input goes into this subgraph, it's alive.
+            if set(component).intersection(inputs_set):
+                reachables.update(component)
+        unreachables = set(self.graph) - reachables
+        # Remove nodes unreachable from input
+        self.graph.remove_nodes_from(unreachables)
+        # Then outputs are nodes with out_degree = 0
+        outputs = [n for n in self.graph if self.graph.out_degree[n] == 0]
+        assert len(outputs) == 1
+        return inputs, outputs[0]
 
-        def get_def_at_pos(node, pos: int):
-            from_, to = list(graph.in_edges(node))[pos]
-            return from_
-
-        for node in list(graph.nodes):
-            if node.op_type != "Shape":
-                continue
-            ng = self.get_next_in_chain(graph, "Gather", node)
-            # Find the second input argument to Gather (will be a Constant node)
-            # and take that away as well.
-            nct = get_def_at_pos(ng, 1)
-            nu = self.get_next_in_chain(graph, "Unsqueeze", ng)
-            nc = self.get_next_in_chain(graph, "Concat", nu)
-            nr = self.get_next_in_chain(graph, "Reshape", nc)
-            if nr is None:
-                continue
-            nodes = [node, ng, nct, nu, nc, nr]
-            gen_node = g.FlattenNode.from_onnx_idiom(nodes)
-            graph = replace_graph_with_node_(graph, nodes, gen_node)
-        return graph
-
-    # This should be the place where partial evaluation happens
-    def _emit_node(self, onnx_node: NodeT) -> Optional[List[g.DFGNode]]:
-        if onnx_node.op_type == "Conv":
-            weight_tensor = self.tensors[onnx_node.input[1]]
-            assert isinstance(weight_tensor, WeightTensor)
+    @staticmethod
+    def _emit_node(in_graph: nx.DiGraph, node: NodeT) -> Optional[EmitNodeT]:
+        predec = sorted_inputs(in_graph, node)
+        if isinstance(node, g.DFGNode):
+            # Directly add the node into the output graph.
+            return node
+
+        attrs = node_attr_to_dict(node)
+        if node.op_type == "Conv":
+            weight_tensor = predec[1]
+            assert isinstance(weight_tensor, g.WeightTensor)
             if len(weight_tensor.shape) != 4:
                 return None  # Only supports 2D conv
-            if len(onnx_node.input) == 2:
-                return [g.Conv2DNode(onnx_node)]
-            else:
-                # Add an intermediate var between conv and add
-                conv_node = g.Conv2DNode(onnx_node)
-                bias_node = g.BiasAddNode(onnx_node)
-                self._split_node_args(conv_node, bias_node)
-                return [conv_node, bias_node]
-        elif onnx_node.op_type in ("MatMul", "Gemm"):
-            if onnx_node.op_type == "Gemm":
-                # Some tensors may need transposing
-                attrs = node_attr_to_dict(onnx_node)
-                # We cannot transpose input tensor (need a transpose op)
-                assert not attrs.get("transA", False)
-                # But we can transpose weight tensor before emitting it
-                if attrs.get("transB", False):
-                    weight_tensor = self.tensors[onnx_node.input[1]]
-                    assert isinstance(weight_tensor, WeightTensor)
-                    weight_tensor.transpose_()
-            if len(onnx_node.input) == 2:
-                return [g.MatMulNode(onnx_node)]
-            else:
-                # Add an intermediate var between matmul and add
-                mul_node = g.MatMulNode(onnx_node)
-                bias_node = g.BiasAddNode(onnx_node)
-                self._split_node_args(mul_node, bias_node)
-                return [mul_node, bias_node]
+            conv_node = g.Conv2DNode(node.name, attrs)
+            if len(predec) == 2:
+                return conv_node
+            # Split into conv followed by an addition
+            bias_node = g.BiasAddNode(f"Bias_{node.name.split('_')[-1]}")
+            return MarkedSubGraph.idiomatic_1to2(conv_node, bias_node, predec)
+        elif node.op_type in ("MatMul", "Gemm"):
+            mul_node = g.MatMulNode(node.name, attrs)
+            if node.op_type == "Gemm":
+                mul_node.gemm_transpose(node, predec)
+            if len(predec) == 2:
+                return mul_node
+            # Split into mul followed by an addition
+            bias_node = g.BiasAddNode(f"Bias_{node.name.split('_')[-1]}")
+            return MarkedSubGraph.idiomatic_1to2(mul_node, bias_node, predec)
         one_to_one_nodes = {
             "MaxPool": g.MaxPool2DNode,
             "AveragePool": g.AveragePool2DNode,
@@ -207,78 +162,121 @@ class DFG(object):
             "BatchNormalization": g.BatchNormalizationNode,
             "Pad": g.PadNode,
             "Identity": g.IdentityNode,
-            "Flatten": g.FlattenNode.from_single_node,
+            "Flatten": g.FlattenNode,
         }
-        if onnx_node.op_type in one_to_one_nodes:
-            return [one_to_one_nodes[onnx_node.op_type](onnx_node)]
+        if node.op_type in one_to_one_nodes:
+            return one_to_one_nodes[node.op_type](node.name, attrs)
         return None
 
-    ################ Internal methods (utils):
 
-    @staticmethod
-    def get_next_in_chain(
-        graph: nx.DiGraph, type_: str, node: Optional[NodeT]
-    ) -> Optional[NodeT]:
-        """
-        Get a unique user node of the unique output of Node `node`,
-        and return it if it has Type `type_`.
-        """
-        if node is None or len(node.output) != 1:
-            return None  # Propagates None; Unique output
-        users = list(graph.neighbors(node))
-        if len(users) != 1 or users[0].op_type != type_:
-            return None  # Unique user of the output; Correct type
-        return users[0]
-
-    def _split_node_args(
-        self, node1: g.DFGNode, node2: g.DFGNode, input_pos: int = 0, pop_pos: int = -1
-    ) -> None:
-        varname = f"conv_{self._var_count}"
-        node1.input.pop(pop_pos)
-        node1.output = [varname]
-        node2.input[input_pos] = varname
-        self._var_count += 1
+def def_use(nodes: Iterable) -> Tuple[dict, dict]:
+    """Computes def/use relation from a list of node.
+
+    This method is duck-typed and operates on any node defining .input and .output.
+    """
+    defs, uses = {}, defaultdict(list)
+    for n in nodes:
+        for i, input_ in enumerate(n.input):
+            uses[input_].append((n, i))
+        for output in n.output:
+            defs[output] = n
+    return defs, uses
+
+
+def detect_flatten(graph: nx.DiGraph) -> nx.DiGraph:
+    """Look for a shape-gather-unsqueeze-concat-reshape chain and replace that with flatten."""
+
+    for node in list(graph.nodes):
+        if node.op_type != "Shape":
+            continue
+        ng = get_next_in_chain(graph, "Gather", node)
+        # Find the second input argument to Gather (will be a Constant node)
+        # and take that away as well.
+        nct = sorted_inputs(graph, ng)[1]
+        nu = get_next_in_chain(graph, "Unsqueeze", ng)
+        nc = get_next_in_chain(graph, "Concat", nu)
+        nr = get_next_in_chain(graph, "Reshape", nc)
+        if nr is None:
+            continue
+        suffix = node.name.split("_")[-1]
+        gen_node = g.FlattenNode(f"Flatten_{suffix}")
+        replace_chain_with_node_(graph, [node, ng, nct, nu, nc, nr], gen_node)
+    return graph
 
-    @staticmethod
-    def _def_use(nodes: Iterable) -> Tuple[dict, dict]:
-        """Computes def/use relation from a list of node.
-
-        This method is duck-typed and operates on any node defining .input and .output.
-        """
-        defs, uses = {}, defaultdict(list)
-        for n in nodes:
-            for input_ in n.input:
-                uses[input_].append(n)
-            for output in n.output:
-                defs[output] = n
-        return defs, uses
-
-
-def replace_graph_with_node_(graph: nx.DiGraph, subgraph: Iterable, node) -> nx.DiGraph:
-    left_neighbors, right_neighbors = set(), set()
-    for n in subgraph:
-        left_neighbors.update(from_ for from_, to in graph.in_edges(n))
-        right_neighbors.update(to for from_, to in graph.out_edges(n))
-        graph.remove_node(n)
-    for n in left_neighbors:
-        if n in graph:
-            graph.add_edge(n, node)
-    for n in right_neighbors:
-        if n in graph:
-            graph.add_edge(node, n)
+
+def get_next_in_chain(
+    graph: nx.DiGraph, type_: str, node: Optional[NodeT]
+) -> Optional[NodeT]:
+    """
+    Get a unique user node of the unique output of Node `node`,
+    and return it if it has Type `type_`.
+    """
+    if node is None or len(node.output) != 1:
+        return None  # Propagates None; Unique output
+    users = list(graph.neighbors(node))
+    if len(users) != 1 or users[0].op_type != type_:
+        return None  # Unique user of the output; Correct type
+    return users[0]
+
+
+def replace_chain_with_node_(graph: nx.DiGraph, chain: list, node) -> nx.DiGraph:
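+    """Replaces a linear `chain` of nodes with the single `node`, rewiring the
+    head's inputs and the tail's outgoing edges (argument indices preserved)."""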
+    inputs = sorted_inputs(graph, chain[0])
+    succ = graph.out_edges(chain[-1], "index")
+    for i, n in enumerate(inputs):
+        graph.add_edge(n, node, index=i)
+    for _, to, index in succ:
+        graph.add_edge(node, to, index=index)
+    graph.remove_nodes_from(chain)
     return graph
 
 
-def replace_node_with_chain_(graph: nx.DiGraph, node, chain: Iterable) -> nx.DiGraph:
-    chain = list(chain)
-    if not chain:
-        graph.remove_node(node)
-        return graph
-    for n1, n2 in zip(chain, chain[1:]):
-        graph.add_edge(n1, n2)  # Add the chain first
-    for from_, _ in graph.in_edges(node):
-        graph.add_edge(from_, chain[0])
-    for _, to in graph.out_edges(node):
-        graph.add_edge(chain[-1], to)
-    graph.remove_node(node)
+def build_graph_with_mapping(
+    graph: nx.DiGraph, node_mapping: Dict[NodeT, EmitNodeT]
+) -> nx.DiGraph:
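+    """Returns a copy of `graph` with each ONNX node replaced per `node_mapping`,
+    either by a single DFGNode or by a MarkedSubGraph."""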
+    graph = graph.copy()
+    single_node, multi_node = {}, {}
+    for replace_node, by_node in node_mapping.items():
+        if isinstance(by_node, g.DFGNode):
+            single_node[replace_node] = by_node
+        else:
+            multi_node[replace_node] = by_node
+    # We do one-to-many replacements first
+    # because their predecessors are specified as onnx nodes.
+    for replace_node, subgraph in multi_node.items():
+        # Add subgraph itself
+        graph = nx.compose(graph, subgraph.subgraph)
+        # Add in edges
+        graph.add_edges_from(subgraph.entry_edges)
+        # Add out edges
+        succ = graph.out_edges(replace_node, "index")
+        for _, to, index in succ:
+            graph.add_edge(subgraph.exit, to, index=index)
+        # Remove old node
+        graph.remove_node(replace_node)
+    # Then do all one-to-one replacements.
+    graph = nx.relabel_nodes(graph, single_node)
     return graph
+
+
+def extract_tensors_from_graph(onnx_graph: GraphT) -> Dict[str, g.TensorNode]:
+    tensors = {}
+    # parse weight
+    weight_cnt = 0
+    for weight_tensor in onnx_graph.initializer:
+        tensors[weight_tensor.name] = g.WeightTensor(
+            weight_tensor, f"weight_{weight_cnt}"
+        )
+        weight_cnt += 1
+    # parse input
+    input_cnt = 0
+    for input_ in onnx_graph.input:
+        if input_.name in tensors:
+            continue
+        tensors[input_.name] = g.InputTensor(input_, f"input_{input_cnt}")
+        input_cnt += 1
+    return tensors
+
+
+def sorted_inputs(graph: nx.DiGraph, node):
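+    """Returns the in-neighbors of `node`, sorted by the "index" edge attribute."""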
+    sorted_edges = sorted(graph.in_edges(node, "index"), key=lambda p: p[2])
+    return [e[0] for e in sorted_edges]
diff --git a/hpvm/projects/onnx/frontend/graph_ir.py b/hpvm/projects/onnx/frontend/graph_ir.py
index 41cd9c185b..9ffff9944c 100644
--- a/hpvm/projects/onnx/frontend/graph_ir.py
+++ b/hpvm/projects/onnx/frontend/graph_ir.py
@@ -3,6 +3,8 @@
 ################################################
 
 
+import abc
+from os import PathLike
 from typing import List
 
 import onnx
@@ -10,12 +12,11 @@ import onnx
 from onnx_attr import node_attr_to_dict
 
 
-class DFGNode:
-    def __init__(self, onnx_node: onnx.NodeProto):
-        self.name = onnx_node.name
-        self.op_type = onnx_node.op_type
-        self.input = onnx_node.input
-        self.output = onnx_node.output
+class DFGNode(abc.ABC):
+    op_type = ""
+
+    def __init__(self, name: str, attrs: dict = None):
+        self.name, self.attrs = name, attrs or {}
 
     def codegen(self):
         return "", []
@@ -24,7 +25,7 @@ class DFGNode:
         return "", []
 
     def __repr__(self):
-        return f"{self.__class__.__name__}({self.input}) -> {self.output}"
+        return f"{self.op_type}({self.name})"
 
 
 ################################################
@@ -32,53 +33,71 @@ class DFGNode:
 ################################################
 
 
-class AddNode(DFGNode):
-    def codegen(self):
-        return "tensorAdd", []
+class TensorNode(DFGNode):
+    def __init__(self, proto: onnx.TensorProto, new_name: str):
+        if not proto.name.strip():
+            raise ValueError("Tensor's name is required.")
+        super().__init__(proto.name, {})
+        self.new_name = new_name
 
-    def hpvm_codegen(self):
-        return "__visc__tensor_add", []
+    def __str__(self):
+        return f"{self.__class__.__name__}: {self.name}"
 
+    __repr__ = __str__
 
-class BiasAddNode(DFGNode):
-    def __init__(self, onnx_conv_node: onnx.NodeProto):
-        super().__init__(onnx_conv_node)
-        self.op_type = "BiasAdd"
-        self.input = [onnx_conv_node.output[0], onnx_conv_node.input[2]]
 
-    def codegen(self):
-        return "tensorAdd", []
+class InputTensor(TensorNode):
+    op_type = "InputTensor"
 
-    def hpvm_codegen(self):
-        return "__visc__tensor_add", []
+    def __init__(self, input_proto: onnx.TensorProto, new_name: str):
+        super().__init__(input_proto, new_name)
+        # get type of input tensor
+        tensor_type = input_proto.type.tensor_type
+        # check if it has a shape:
+        shape = tensor_type.shape
+        self.shape: List[int] = [d.dim_value for d in shape.dim]
 
 
-class MatMulNode(DFGNode):
-    def codegen(self):
-        return "tensorGemmGPU", []
+class WeightTensor(TensorNode):
+    op_type = "WeightTensor"
 
-    def hpvm_codegen(self):
-        return "__visc__tensor_mul", []
+    def __init__(self, weight_proto: onnx.TensorProto, new_name: str):
+        from onnx import numpy_helper
 
+        super().__init__(weight_proto, new_name)
+        self.shape = []
+        self.input_data = numpy_helper.to_array(weight_proto)
+        sh = self.input_data.shape
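+        # Normalize the weight shape to 4-D: pad 1-D/2-D shapes with 1s,
+        # keep 4-D as-is, and fall back to all-1s otherwise.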
+        if len(sh) == 1:
+            self.shape = [1, sh[0], 1, 1]
+        elif len(sh) == 2:
+            self.shape = [1, 1, sh[0], sh[1]]
+        elif len(sh) == 4:
+            self.shape = [sh[0], sh[1], sh[2], sh[3]]
+        else:
+            self.shape = [1] * 4
 
-class SoftMaxNode(DFGNode):
-    def codegen(self):
-        return "tensorSoftmax", []
+    def dump_weight(self, file_name: PathLike):
+        self.input_data.tofile(file_name)
 
-    def hpvm_codegen(self):
-        return "__visc__tensor_softmax", []
+    def transpose_(self):
+        if len(self.input_data.shape) != 2:
+            raise ValueError("Can only transpose 2D array")
+        self.input_data = self.input_data.T
+        self.shape[2], self.shape[3] = self.shape[3], self.shape[2]
 
 
 class Conv2DNode(DFGNode):
-    def __init__(self, onnx_node: onnx.NodeProto):
-        super().__init__(onnx_node)
-        attrs = node_attr_to_dict(onnx_node)
-        padding = attrs["pads"]
+    op_type = "Conv2D"
+
+    def __init__(self, name: str, attrs: dict):
+        super().__init__(name, attrs)
+        padding = self.attrs["pads"]
         assert len(padding) == 4, "2D convolution must have 4 padding values"
         if any(p != padding[0] for p in padding[1:]):
             raise ValueError("Convolution with different padding is unsupported")
         self.padding = padding[0]
-        self.strides = attrs["strides"]
+        self.strides = self.attrs["strides"]
 
     def codegen(self):
         return (
@@ -93,14 +112,14 @@ class Conv2DNode(DFGNode):
         )
 
 
-class MaxPool2DNode(DFGNode):
-    def __init__(self, onnx_node: onnx.NodeProto):
-        super().__init__(onnx_node)
-        attr = node_attr_to_dict(onnx_node)
-        self.strides = attr["strides"]
-        self.pool_size = attr["kernel_shape"]
+class _Pool2DNode(DFGNode, abc.ABC):
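+    """Common base of 2D pooling nodes; `pool_type` selects the pooling kernel
+    in tensorPooling ("0" = max, "1" = average)."""
+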
+    pool_type = "0"
+
+    def __init__(self, name: str, attrs: dict):
+        super().__init__(name, attrs)
+        self.strides = self.attrs["strides"]
+        self.pool_size = self.attrs["kernel_shape"]
         self.padding = 0
-        self.pool_type = "0"
 
     def codegen(self):
         return (
@@ -121,35 +140,75 @@ class MaxPool2DNode(DFGNode):
         )
 
 
-class AveragePool2DNode(DFGNode):
-    def __init__(self, onnx_node: onnx.NodeProto):
-        super().__init__(onnx_node)
-        attr = node_attr_to_dict(onnx_node)
-        self.strides = attr["strides"]
-        self.pool_size = attr["kernel_shape"]
-        self.padding = 0
-        self.pool_type = "1"
+class MaxPool2DNode(_Pool2DNode):
+    pool_type = "0"
+    op_type = "MaxPool2D"
+
+
+class AveragePool2DNode(_Pool2DNode):
+    pool_type = "1"
+    op_type = "AveragePool2D"
+
+
+class BiasAddNode(DFGNode):
+    op_type = "BiasAdd"
 
     def codegen(self):
-        return (
-            "tensorPooling",
-            [
-                self.pool_type,
-                *self.pool_size,
-                self.padding,
-                self.padding,
-                *self.strides,
-            ],
-        )
+        return "tensorAdd", []
 
     def hpvm_codegen(self):
-        return (
-            "__visc__tensor_pool_avg",
-            [*self.pool_size, self.padding, self.padding, *self.strides],
-        )
+        return "__visc__tensor_add", []
+
+
+class MatMulNode(DFGNode):
+    op_type = "MatMul"
+
+    def codegen(self):
+        return "tensorGemmGPU", []
+
+    def hpvm_codegen(self):
+        return "__visc__tensor_mul", []
+
+    @staticmethod
+    def gemm_transpose(onnx_node, predec):
+        def _transpose(weight):
+            if not isinstance(weight, WeightTensor):
+                raise ValueError(
+                    f"Cannot transpose non-const {weight} (transpose op needed)"
+                )
+            weight.transpose_()
+
+        # Some tensors may need transposing
+        attrs = node_attr_to_dict(onnx_node)
+        if attrs.get("transA", False):
+            _transpose(predec[0])
+        if attrs.get("transB", False):
+            _transpose(predec[1])
+
+
+class SoftMaxNode(DFGNode):
+    op_type = "SoftMax"
+
+    def codegen(self):
+        return "tensorSoftmax", []
+
+    def hpvm_codegen(self):
+        return "__visc__tensor_softmax", []
+
+
+class AddNode(DFGNode):
+    op_type = "Add"
+
+    def codegen(self):
+        return "tensorAdd", []
+
+    def hpvm_codegen(self):
+        return "__visc__tensor_add", []
 
 
 class ReluNode(DFGNode):
+    op_type = "ReLU"
+
     def codegen(self):
         return "tensorRelu", []
 
@@ -158,6 +217,8 @@ class ReluNode(DFGNode):
 
 
 class TanhNode(DFGNode):
+    op_type = "Tanh"
+
     def codegen(self):
         return "tensorTanh", []
 
@@ -166,10 +227,11 @@ class TanhNode(DFGNode):
 
 
 class BatchNormalizationNode(DFGNode):
-    def __init__(self, onnx_node: onnx.NodeProto):
-        super().__init__(onnx_node)
-        attr = node_attr_to_dict(onnx_node)
-        self.epsilon = attr["epsilon"]
+    op_type = "BN"
+
+    def __init__(self, name: str, attrs: dict):
+        super().__init__(name, attrs)
+        self.epsilon = self.attrs["epsilon"]
 
     def codegen(self):
         return "tensorBatchNorm", [self.epsilon]
@@ -179,20 +241,7 @@ class BatchNormalizationNode(DFGNode):
 
 
 class FlattenNode(DFGNode):
-    def __init__(self, name: str, op_type: str, input, output):
-        self.name = name
-        self.op_type = op_type
-        self.input = input
-        self.output = output
-
-    @classmethod
-    def from_single_node(cls, n: onnx.NodeProto):
-        return cls(n.name, n.op_type, n.input, n.output)
-
-    @classmethod
-    def from_onnx_idiom(cls, nodes: List[onnx.NodeProto]):
-        _, suffix = nodes[0].name.split("_")
-        return cls(f"Flatten_{suffix}", "Flatten", nodes[0].input, nodes[-1].output)
+    op_type = "Flatten"
 
 
 class ActivationNode(DFGNode):
diff --git a/hpvm/projects/onnx/frontend/main.py b/hpvm/projects/onnx/frontend/main.py
index 184c6b9ea8..25e76ae8fa 100644
--- a/hpvm/projects/onnx/frontend/main.py
+++ b/hpvm/projects/onnx/frontend/main.py
@@ -1,5 +1,5 @@
 from pathlib import Path
-from typing import List, Optional
+from typing import Optional
 
 import onnx
 
@@ -32,21 +32,21 @@ def compile(
     opset: Optional[int],
     hpvmc: bool,
 ):
-    from graph_builder import GraphBuilder
+    from graph_builder import DFG
     from codegen_tensor import TensorCodeGen
     from codegen_hpvm import HpvmCodeGen
 
     model = onnx.load(onnx_file)
     if opset is not None:
         model = check_version(model, opset)
-    graphBuilder = GraphBuilder(model)
+    dfg = DFG(model.graph)
     if hpvmc:
-        hpvmCodeGen = HpvmCodeGen(graphBuilder.dfg, output_dir, input_size, batch_size, prefix)
+        hpvmCodeGen = HpvmCodeGen(dfg, output_dir, input_size, batch_size, prefix)
         hpvmCodeGen.compile()
     else:
-        TensorCodeGen = TensorCodeGen(graphBuilder.dfg, output_dir, input_size)
-        TensorCodeGen.compile()
+        tensorCodeGen = TensorCodeGen(dfg, output_dir, input_size)
+        tensorCodeGen.compile()
-    graphBuilder.dump_weights(output_dir)
+    dfg.dump_weights(output_dir)
 
 
 def parse_args():
diff --git a/hpvm/projects/onnx/frontend/tensor.py b/hpvm/projects/onnx/frontend/tensor.py
deleted file mode 100644
index ded0162631..0000000000
--- a/hpvm/projects/onnx/frontend/tensor.py
+++ /dev/null
@@ -1,54 +0,0 @@
-from os import PathLike
-from typing import List
-
-import onnx
-from onnx import numpy_helper
-
-
-class Tensor(object):
-    def __init__(self, proto: onnx.TensorProto, new_name: str):
-        if not proto.name.strip():
-            raise ValueError("Tensor's name is required.")
-        self.name = proto.name
-        self.new_name = new_name
-
-    def __str__(self):
-        return f"{self.__class__.__name__}: {self.name}"
-
-    __repr__ = __str__
-
-
-class InputTensor(Tensor):
-    def __init__(self, input_proto: onnx.TensorProto, new_name: str):
-        super().__init__(input_proto, new_name)
-        # get type of input tensor
-        tensor_type = input_proto.type.tensor_type
-        # check if it has a shape:
-        shape = tensor_type.shape
-        self.shape: List[int] = [d.dim_value for d in shape.dim]
-
-
-# Can be either input or weight tensor
-class WeightTensor(Tensor):
-    def __init__(self, weight_proto: onnx.TensorProto, new_name: str):
-        Tensor.__init__(self, weight_proto, new_name)
-        self.shape = []
-        self.input_data = numpy_helper.to_array(weight_proto)
-        sh = self.input_data.shape
-        if len(sh) == 1:
-            self.shape = [1, sh[0], 1, 1]
-        elif len(sh) == 2:
-            self.shape = [1, 1, sh[0], sh[1]]
-        elif len(sh) == 4:
-            self.shape = [sh[0], sh[1], sh[2], sh[3]]
-        else:
-            self.shape = [1] * 4
-
-    def dump_weight(self, file_name: PathLike):
-        self.input_data.tofile(file_name)
-
-    def transpose_(self):
-        if len(self.input_data.shape) != 2:
-            raise ValueError("Can only transpose 2D array")
-        self.input_data = self.input_data.T
-        self.shape[3], self.shape[2] = self.shape[2:]
-- 
GitLab