diff --git a/hpvm/projects/onnx/frontend/codegen_tensor.py b/hpvm/projects/onnx/frontend/codegen_tensor.py
index 132debc269e8be8259d47d74f4f4e906c600e81a..d382db94d3b7eef9ad2bc5bfb6aa77fea5f989c6 100644
--- a/hpvm/projects/onnx/frontend/codegen_tensor.py
+++ b/hpvm/projects/onnx/frontend/codegen_tensor.py
@@ -1,141 +1,63 @@
-import sys
-import numpy as np
-import os
+from codegen_hpvm import emit_weights, get_input_args, make_c_identifier
+from os import PathLike
+from pathlib import Path
+from typing import Dict, List
 
+import jinja2
+
+from graph_builder import DFG
 from tensor import WeightTensor
-from utils import skip_layer
+
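+# The C++ output is rendered from a Jinja2 template; the loader resolves it
+# relative to this file so codegen works from any working directory, and
+# trim_blocks strips the newline after each {% ... %} block tag.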
+TEMPLATE_FILE = "template_tensor.cpp"
+loader = jinja2.FileSystemLoader(searchpath=str(Path(__file__).parent))
+template_env = jinja2.Environment(loader=loader, trim_blocks=True)
+template = template_env.get_template(TEMPLATE_FILE)
+
 
 class TensorCodeGen(object):
-    def __init__(self, dfg, weights_dir, test_data_shape=None):
-        self.program_str = ""
-        self.graph = dfg.graph
+    def __init__(self, dfg: DFG, output_dir: PathLike, test_data_shape=None):
         self.tensors = dfg.tensors
-        self.nodes = dfg.nodes
-        self.var_cnt = -1
-        self.weights_dir = weights_dir
+        self.dfg = dfg
+        self.var_count = 0
+        self.output_dir = Path(output_dir)
         self.test_data_shape = test_data_shape
+        # self.variables maps each ONNX tensor name to the name of the
+        # C variable that holds that tensor in the generated code.
+        input_args = get_input_args(dfg.inputs, dfg.tensors)
+        self.variables: Dict[str, str] = {k: make_c_identifier(k) for k in input_args}
 
     ################################################
     # Aux functions
     ################################################
-    def get_last_var(self):
-        return "var_" + str(self.var_cnt)
-
-    def get_new_var(self):
-        self.var_cnt = self.var_cnt + 1
-        return "var_" + str(self.var_cnt)
+    def _allocate_varname(self) -> str:
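+        """Return a fresh C variable name ("var_0", "var_1", ...)."""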
+        varname = f"var_{self.var_count}"
+        self.var_count += 1
+        return varname
 
     ################################################
     # CodeGen functions
     ################################################
-    def emit_weights(self):
-        weights_str = ""
-        weights_str += "std::string dir_prefix = std::string(\"" + str(self.weights_dir) + "\");\n"
-        weights_str += "std::string input_path =  dir_prefix + std::string(\"input.bin\");\n"
-        weights_str += "std::string labels_path =  dir_prefix + std::string(\"labels.bin\");\n"
-        for tensor in self.tensors.values():
-            if isinstance(tensor, WeightTensor):
-                weights_str += self.emit_single_weight(tensor)
-        self.program_str += weights_str
-
-    def emit_single_weight(self, tensor):
-        N = tensor.shape[0]
-        C = tensor.shape[1]
-        H = tensor.shape[2]
-        W = tensor.shape[3]
-        mapped_name = tensor.get_mapped_name()
-        file_path = mapped_name + "_path" 
-        unique_file_name = file_path + ".bin"
-        weight_str = "std::string " + file_path + " = " + " dir_prefix + std::string(\""
-        weight_str += unique_file_name + "\"); \n"
-        weight_str += "void* " + mapped_name + " = " + " readTrainedWeights("
-        weight_str += file_path + ".c_str(), 0," + str(N) + "," + str(C) + "," + str(H) + "," + str(W)
-        weight_str += "); \n"
-        return weight_str
-
-    def emit_graph(self):
-        for node in self.nodes:
-            # check if all inputs of this node is mapped
-            for i in node.input:
-                self.tensors[i].get_mapped_name() 
-            # set var name for output node
-            if len(node.output) > 1:
-                raise ValueError("Output number for a single layer larger than 1!")
-            if node.op_type in skip_layer or node.op_type == "BiasAdd":
-                mapped_output_name = self.get_last_var()
-            else:
-                mapped_output_name = self.get_new_var()
-            self.tensors[node.output[0]].set_mapped_name(mapped_output_name)
-            self.program_str += node.codegen(self.tensors)
-
-    def emit_header(self):
-        headers = "\n#include <stdio.h> \n"
-        headers += "#include <stdlib.h> \n"
-        headers += "#include <unistd.h> \n"
-        headers += "#include <fcntl.h> \n"
-        headers += "#include <sys/types.h> \n"
-        headers += "#include <sys/stat.h> \n"
-        headers += "#include <string.h> \n"
-        headers += "#include <chrono> \n"
-        headers += "#include \"../../tensor_runtime/include/tensor_runtime.h\" \n"
-        headers += "#include \"../include/utils.h\" \n\n"
-
-        main_func = "int main(){ \n\n"
-        initialization = "llvm_hpvm_initTensorRt(0); \n\n"
-        self.program_str += headers
-        self.program_str += main_func
-        self.program_str += initialization
-
-    def emit_footer(self, test_data=None):
-        destructors = "\nllvm_hpvm_cleanupTensorRt(); \n"
-        end_main = "\nreturn 0; \n\n}\n"
-        self.program_str += destructors
-        self.program_str += end_main
-
-    def emit_batch_loop(self, test_data_shape):
-        N = test_data_shape[0]
-        C = test_data_shape[1]
-        H = test_data_shape[2]
-        W = test_data_shape[3]
-
-        loop_str = ""
-        loop_str += "\nstartMemTracking(); \n\n"
-
-        loop_str += "int test_input_size = " + str(N) + "; \n"
-        loop_str += "int batch_size = " + str(N) + "; \n"
-        # FIXME: Ceiling for batch_count
-        loop_str += "int batch_count = test_input_size / batch_size; \n"
-        loop_str += "float final_accuracy = 0.0; \n\n"
-
-        loop_str += "for(int i = 0; i < batch_count; i++){ \n\n"
-        loop_str += "int start = i * batch_size; \n"
-        loop_str += "int end = (i + 1) * batch_size; \n"
-
-        loop_str += "\nvoid* input = readInputBatch(input_path.c_str(),0,start,end,"
-        loop_str += str(C) + "," + str(H) + "," + str(W) + "); \n\n"
-
-        self.program_str += loop_str
-
-    def emit_batch_loop_end(self):
-        end_loop_str = ""
-        end_loop_str += "\nuint32_t* labels = readLabelsBatch3(labels_path.c_str(),start,end); \n"
-        mapped_output_var = self.tensors[self.graph.output[0].name].get_mapped_name()
-        accuracy_call = "\nfloat accuracy = computeAccuracy3(labels, " + \
-            mapped_output_var + "); \n"
-        end_loop_str += accuracy_call
-        end_loop_str += "final_accuracy += accuracy; \n"
-        end_loop_str += "freeBatchMemory(); \n "
-        end_loop_str += "\n}\n\n"
-
-        end_loop_str += "final_accuracy = final_accuracy / batch_count; \n"
-        end_loop_str += "dumpFinalAccuracy(final_accuracy); \n\n"
-
-        self.program_str += end_loop_str
 
-    def emit_source_to_file(self, src_dir):
-        f = open(src_dir + "/src.cc", "w+")
-        f.write(self.program_str)
-        f.close()
+    def emit_graph(self) -> List[dict]:
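+        """Emit one {output, inputs, function} record per DFG node, in
+        topological order, for the template's code-generation loop."""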
+        graph_code = []
+        for node in self.dfg.traverse_order:
+            func_name, extra_args = node.codegen()
+            if func_name == "":  # No code generation
+                # Such a node must have a single input; we alias its output
+                # to its input and skip code generation.
+                assert len(node.input) == 1 and len(node.output) == 1
+                self.variables[node.output[0]] = self.variables[node.input[0]]
+                continue
+            input_args = [self.variables[arg] for arg in node.input] + extra_args
+            varname = self._allocate_varname()
+            self.variables[node.output[0]] = varname
+            graph_code.append(
+                {"output": varname, "inputs": input_args, "function": func_name}
+            )
+        return graph_code
 
     ################################################
     # Compile is a top level function to compile an onnx model into C/C++
@@ -143,13 +65,19 @@ class TensorCodeGen(object):
     ################################################
 
     def compile(self):
-        #if os.path.exists(self.weights_dir):
-        #    raise ValueError("Weight dir existed. Compilation interrupted!")
-        #os.mkdir(self.weights_dir)
-        self.emit_header()
-        self.emit_weights()
-        self.emit_batch_loop(self.test_data_shape)
-        self.emit_graph()
-        self.emit_batch_loop_end()
-        self.emit_footer()
-        self.emit_source_to_file(self.weights_dir)
+        """Render the C++ template over the generated graph code and the
+        weight-loading info, writing the result to src.cc."""
+        graph_code = self.emit_graph()
+        input_arg = self.dfg.discover_input_var()
+        output_arg = self.variables[self.dfg.output.name]
+        with open(self.output_dir / "src.cc", "w") as f:
+            f.write(
+                template.render(
+                    input=input_arg,
+                    input_shape=self.test_data_shape,
+                    output=output_arg,
+                    graph_code=graph_code,
+                    weights=emit_weights(self.tensors),
+                    output_dir=self.output_dir,
+                )
+            )
diff --git a/hpvm/projects/onnx/frontend/graph_builder.py b/hpvm/projects/onnx/frontend/graph_builder.py
index f7260f2b2c1edfc1f88271e99386a1db6d20319d..542cfb535809f0334ce0a5c953023b933e4f1ecb 100644
--- a/hpvm/projects/onnx/frontend/graph_builder.py
+++ b/hpvm/projects/onnx/frontend/graph_builder.py
@@ -90,8 +90,8 @@ class DFG(object):
     def __init__(self, graph: GraphT, tensors: Dict[str, Tensor]):
         if len(graph.output) > 1:
             raise ValueError("Only single-output graph is supported")
-        self.inputs: List[str] = graph.input
-        self.output: str = graph.output[0]
+        self.inputs: List[NodeT] = graph.input
+        self.output: NodeT = graph.output[0]
         self._onnx_defs, self._onnx_uses = self.def_use(graph.node)
         self._var_count = 0
         self.tensors = tensors
@@ -101,6 +101,19 @@ class DFG(object):
     def traverse_order(self) -> List[g.DFGNode]:
         return list(nx.topological_sort(self.graph))
 
+    def discover_input_var(self) -> str:
+        """Guess which input tensor is the actual "input" to the ONNX model.
+
+        This is useful when we batch through the input tensor.
+        It's a guess because some ONNX models put everything (weights, etc.)
+        in the inputs, and there's no apparent way to tell which one is
+        the actual input (that we can batch over)."""
+
+        if len(self.inputs) == 1:
+            return self.inputs[0].name
+        # Heuristic: the first input of the first node in topological order.
+        return self.traverse_order[0].input[0]
+
     @staticmethod
     def def_use(nodes: list) -> Tuple[dict, dict]:
         """Computes def/use relation from a list of node.
diff --git a/hpvm/projects/onnx/frontend/graph_ir.py b/hpvm/projects/onnx/frontend/graph_ir.py
index 5c0a2eef3cfd70e78efd95da5a36dc8cde65d68e..467ed63ccb6f292c7c7f7bc9f8a1f1c9e3b8d167 100644
--- a/hpvm/projects/onnx/frontend/graph_ir.py
+++ b/hpvm/projects/onnx/frontend/graph_ir.py
@@ -10,8 +10,8 @@ class DFGNode(object):
         self.input = onnx_node.input
         self.output = onnx_node.output
 
-    def codegen(self, tensors):
-        return "\n***Not Implemented***\n"
+    def codegen(self):
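+        """Return (runtime_function_name, extra_args) for this node; an empty
+        function name means "emit no code" (the output aliases the input)."""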
+        return "", []
 
     def hpvm_codegen(self):
         return "", []
@@ -45,14 +45,8 @@ class LogicalOpNode(DFGNode):
 
 class AddNode(DFGNode):
 
-    def codegen(self, tensors):
-        inst_str = ""
-        left_input = tensors[self.input[0]].get_mapped_name()
-        right_input = tensors[self.input[1]].get_mapped_name()
-        mapped_output_name = tensors[self.output[0]].get_mapped_name()
-        inst_str += "void* " + mapped_output_name + " = "
-        inst_str += "tensorAdd(" + left_input + ", " + right_input + "); \n"
-        return inst_str
+    def codegen(self):
+        return "tensorAdd", []
 
     def hpvm_codegen(self):
         return "__hpvm__tensor_add", []
@@ -66,43 +60,23 @@ class BiasAddNode(DFGNode):
         self.input.append(self.output[0])
         self.input.append(onnx_conv_node.input[2])
 
-    def codegen(self, tensors):
-        inst_str = ""
-        left_input = tensors[self.input[0]].get_mapped_name()
-        right_input = tensors[self.input[1]].get_mapped_name()
-        mapped_output_name = tensors[self.output[0]].get_mapped_name()
-        inst_str += mapped_output_name + " = "
-        inst_str += "tensorAdd(" + left_input + ", " + right_input + "); \n"
-        return inst_str
+    def codegen(self):
+        return "tensorAdd", []
 
     def hpvm_codegen(self):
         return "__hpvm__tensor_add", []
 
 class MatMulNode(DFGNode):
-
-    def codegen(self, tensors):
-        inst_str = ""
-        left_input = tensors[self.input[0]].get_mapped_name()
-        right_input = tensors[self.input[1]].get_mapped_name()
-        mapped_output_name = tensors[self.output[0]].get_mapped_name()
-        inst_str += "void* " + mapped_output_name + " = "
-        inst_str += "tensorGemmGPU(" + left_input + \
-            ", " + right_input + "); \n"
-        return inst_str
+    def codegen(self):
+        return "tensorGemmGPU", []
 
     def hpvm_codegen(self):
         return "__hpvm__tensor_mul", []
 
 
 class SoftMaxNode(DFGNode):
-
-    def codegen(self, tensors):
-        inst_str = ""
-        mapped_input_name = tensors[self.input[0]].get_mapped_name()
-        mapped_output_name = tensors[self.output[0]].get_mapped_name()
-        inst_str += "void* " + mapped_output_name + " = "
-        inst_str += "tensorSoftmax(" + mapped_input_name + "); \n"
-        return inst_str
+    def codegen(self):
+        return "tensorSoftmax", []
 
     def hpvm_codegen(self):
         return "__hpvm__tensor_softmax", []
@@ -127,20 +101,8 @@ class Conv2DNode(DFGNode):
                 for stride in attr.ints:
                     self.strides.append(stride)
 
-    def codegen(self, tensors):
-        inst_str = ""
-        mapped_input_name  = tensors[self.input[0]].get_mapped_name()
-        mapped_output_name = tensors[self.output[0]].get_mapped_name()
-
-        inst_str += "void* " + mapped_output_name + " = "
-        inst_str += "tensorConvolution(" + mapped_input_name  + ", "
-        inst_str += tensors[self.input[1]].get_mapped_name() + ", "
-        inst_str += str(self.padding) + ", "
-        inst_str += str(self.padding) + ", "
-        inst_str += str(self.strides[0]) + ", "
-        inst_str += str(self.strides[1]) + ", "
-        inst_str += "1, 1); \n"
-        return inst_str
+    def codegen(self):
+        # The trailing 1, 1 matches the last two arguments the old string
+        # emitter always passed to tensorConvolution.
+        return "tensorConvolution", [
+            self.padding, self.padding, self.strides[0], self.strides[1], 1, 1,
+        ]
 
     def hpvm_codegen(self):
         return "__hpvm__tensor_convolution", [self.padding, self.padding, self.strides[0], self.strides[1]]
@@ -162,17 +124,8 @@ class MaxPool2DNode(DFGNode):
                     self.strides.append(stride)
 
 
-    def codegen(self, tensors):
-        mapped_input_name  = tensors[self.input[0]].get_mapped_name()
-        mapped_output_name = tensors[self.output[0]].get_mapped_name()
-        # tensorPooling(input, pool_type, pool_h, pool_w, v_pad, h_pad, v_stride, h_stride)
-        inst_str = "void* " + mapped_output_name + " = "
-        inst_str += "tensorPooling(" + mapped_input_name  + "," + \
-            self.pool_type + "," + str(self.pool_size[0]) + "," + str(self.pool_size[1])
-        inst_str += "," + str(self.padding) + "," + str(self.padding) + \
-            "," + str(self.strides[0]) + "," + str(self.strides[1])
-        inst_str += "); \n"
-        return inst_str
+    def codegen(self):
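+        # tensorPooling(input, pool_type, pool_h, pool_w, v_pad, h_pad, v_stride, h_stride)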
+        return "tensorPooling", [self.pool_type, *self.pool_size, self.padding, self.padding, *self.strides]
 
     def hpvm_codegen(self):
         return "__hpvm__tensor_pool_max", [*self.pool_size, self.padding, self.padding, *self.strides]
@@ -194,16 +147,8 @@ class AveragePool2DNode(DFGNode):
                 for stride in attr.ints:
                     self.strides.append(stride)
 
-    def codegen(self, tensors):
-        mapped_input_name  = tensors[self.input[0]].get_mapped_name()
-        mapped_output_name = tensors[self.output[0]].get_mapped_name()
-        inst_str = "void* " + mapped_output_name + " = "
-        inst_str += "tensorPooling(" + mapped_input_name  + "," + \
-            self.pool_type + "," + str(self.pool_size[0]) + "," + str(self.pool_size[1])
-        inst_str += "," + str(self.padding) + "," + str(self.padding) + \
-            "," + str(self.strides[0]) + "," + str(self.strides[1])
-        inst_str += "); \n"
-        return inst_str
+    def codegen(self):
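+        # tensorPooling(input, pool_type, pool_h, pool_w, v_pad, h_pad, v_stride, h_stride)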
+        return "tensorPooling", [self.pool_type, *self.pool_size, self.padding, self.padding, *self.strides]
 
     def hpvm_codegen(self):
         return "__hpvm__tensor_pool_avg", [*self.pool_size, self.padding, self.padding, *self.strides]
@@ -211,24 +156,16 @@ class AveragePool2DNode(DFGNode):
 
 class ReluNode(DFGNode):
 
-    def codegen(self, tensors):
-        mapped_input_name = tensors[self.input[0]].get_mapped_name()
-        mapped_output_name = tensors[self.output[0]].get_mapped_name()
-        inst_str = "void* " + mapped_output_name + " = "
-        inst_str += "tensorRelu(" + mapped_input_name + "); \n"
-        return inst_str
+    def codegen(self):
+        return "tensorRelu", []
 
     def hpvm_codegen(self):
         return "__hpvm__tensor_relu", []
 
 class TanhNode(DFGNode):
 
-    def codegen(self, tensors):
-        mapped_input_name = tensors[self.input[0]].get_mapped_name()
-        mapped_output_name = tensors[self.output[0]].get_mapped_name()
-        inst_str = "void* " + mapped_output_name + " = "
-        inst_str += "tensorTanh(" + mapped_input_name + "); \n"
-        return inst_str
+    def codegen(self):
+        return "tensorTanh", []
 
     def hpvm_codegen(self):
         return "__hpvm__tensor_tanh", []
@@ -243,31 +180,13 @@ class BatchNormalizationNode(DFGNode):
             if attr.name == "epsilon":
                 self.epsilon = str(attr.f)
 
-    def codegen(self, tensors):
-        mapped_input_name = tensors[self.input[0]].get_mapped_name()
-        mapped_output_name = tensors[self.output[0]].get_mapped_name()
-        inst_str = "void* " + mapped_output_name + " = "
-        inst_str += "tensorBatchNorm(" + mapped_input_name + ", "
-        inst_str += tensors[self.input[1]].get_mapped_name() + ", "
-        inst_str += tensors[self.input[2]].get_mapped_name() + ", "
-        inst_str += tensors[self.input[3]].get_mapped_name() + ", "
-        inst_str += tensors[self.input[4]].get_mapped_name() + ", "
-        inst_str += str(self.epsilon)
-        inst_str += "); \n"
-        return inst_str
+    def codegen(self):
+        return "tensorBatchNorm", [self.epsilon]
 
     def hpvm_codegen(self):
         return "__hpvm__tensor_batchnorm", [self.epsilon]
 
 
-class PadNode(DFGNode):
-    def codegen(self, tensors):
-        return ""
-
-class IdentityNode(DFGNode):
-    def codegen(self, tensors):
-        return ""
-
 class FlattenNode(DFGNode):
     def __init__(self, name: str, op_type: str, input, output):
         self.name = name
@@ -284,9 +203,6 @@ class FlattenNode(DFGNode):
         _, suffix = nodes[0].name.split('_')
         return cls(f'Flatten_{suffix}', 'Flatten', nodes[0].input, nodes[-1].output)
 
-    def codegen(self, tensors):
-        return ""
-
 class ZeroPadding2DNode(DFGNode):
     pass
 
@@ -295,3 +211,10 @@ class DepthwiseConv2DNode(DFGNode):
 
 class DenseNode(DFGNode):
     pass
+
+
+class PadNode(DFGNode):
+    pass
+
+
+class IdentityNode(DFGNode):
+    pass
diff --git a/hpvm/projects/onnx/frontend/template_tensor.cpp b/hpvm/projects/onnx/frontend/template_tensor.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..ae3060836451069ed6aee9321d714db1b66c1723
--- /dev/null
+++ b/hpvm/projects/onnx/frontend/template_tensor.cpp
@@ -0,0 +1,50 @@
+#include <chrono>
+#include <fcntl.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "../../tensor_runtime/include/tensor_runtime.h"
+#include "../include/utils.h"
+
+int main() {
+  std::string dir_prefix = "{{output_dir}}/";
+  std::string input_path = dir_prefix + "input.bin";
+  std::string labels_path = dir_prefix + "labels.bin";
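+  // Load each trained weight tensor from its .bin file under dir_prefix.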
+{% for w in weights %}
+  std::string {{w.name}}_path = dir_prefix + std::string("{{w.filename}}");
+  void* {{w.name}} = readTrainedWeights({{w.name}}_path.c_str(), 0, {{w.shape|join(', ')}});
+{% endfor %}
+
+  llvm_hpvm_initTensorRt(0);
+  startMemTracking();
+
+  int test_input_size = {{input_shape[0]}};
+  int batch_size = {{input_shape[0]}};
+  // FIXME: batch_count should be the ceiling of test_input_size / batch_size.
+  int batch_count = test_input_size / batch_size;
+  float final_accuracy = 0.0;
+
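+  // Process the test set batch by batch, accumulating accuracy.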
+  for (int i = 0; i < batch_count; i++) {
+    int start = i * batch_size;
+    int end = (i + 1) * batch_size;
+    // readInputBatch takes the per-sample shape (C, H, W); batch size comes from start/end.
+    void *{{input}} = readInputBatch(input_path.c_str(), 0, start, end, {{input_shape[1:]|join(', ')}});
+
+{% for code in graph_code %}
+    auto {{code.output}} = {{code.function}}({{code.inputs|join(', ')}});
+{% endfor %}
+
+    uint32_t* labels = readLabelsBatch3(labels_path.c_str(), start, end);
+    float accuracy = computeAccuracy3(labels, {{output}});
+    final_accuracy += accuracy;
+    freeBatchMemory();
+  }
+
+  final_accuracy = final_accuracy / batch_count;
+  dumpFinalAccuracy(final_accuracy);
+  llvm_hpvm_cleanupTensorRt();
+  return 0;
+}