diff --git a/hpvm/projects/onnx/frontend/graph_builder.py b/hpvm/projects/onnx/frontend/graph_builder.py
index f5cc3fa158bead127eebaac50afd6f9ea783ae4e..192e6afbb8fdb108ab51eace5dd4269fd1be1064 100644
--- a/hpvm/projects/onnx/frontend/graph_builder.py
+++ b/hpvm/projects/onnx/frontend/graph_builder.py
@@ -1,6 +1,7 @@
 import sys
 from onnx import numpy_helper
 from tensor import InputTensor, WeightTensor
+from graph_ir import *
 
 support_onnx_ops = {"DepthwiseConv" : [2],
                "Conv" : [2], # only 2d supported here
@@ -162,98 +163,37 @@ class DFG(object):
     def __init__(self, graph, tensors):
         self.graph = graph
         self.tensors = tensors
-
-    def hasSingleInput(self, layer):
-        layer_name = layer.__class__.__name__
-        return layer_name in self.singleInputLayers
-
-    def hasMultipleInputs(self, layer):
-        layer_name = layer.__class__.__name__
-        return layer_name in self.multiInputLayers
-
-    def add_dfg_edge(self, inbound_node_name, dfg_node):
-        inbound_node_name = inbound_node_name.split(":")[0]
-        inbound_node_name = inbound_node_name.split("/")[0]
-        if inbound_node_name in self.node_map:
-            inbound_node = self.node_map[inbound_node_name]
-            print(inbound_node_name, " found!")
-            inbound_node.add_output(dfg_node)
-            dfg_node.add_input(inbound_node)
-        else:
-            print("--inbound node NOT FOUND!")
-
-    def add_to_graph(self, layer):
-        dfg_node = DFGNode(layer)
-        if not self.root_set:
-            self.root_node = dfg_node
-            self.root_set = True  # DFG root node is now set
-
-        if self.hasMultipleInputs(layer):
-            for j in range(len(layer.input)):
-                print(type(layer.input[j]))
-                print(layer.input[j].op.name)
-                self.add_dfg_edge(layer.input[j].op.name, dfg_node)
-        else:
-            print(layer.input.name)
-            self.add_dfg_edge(layer.input.name, dfg_node)
-        # Adding DFG node to name mapping
-        self.node_map[layer.name] = dfg_node
-
-    # Check if all predecessor nodes have been visited thus far - reverse
-    # postorder traversal
-
-    def predVisited(self, cur_node, visited_nodes):
-        for input_node in cur_node.inputs:
-            if input_node.layer_name not in visited_nodes:
-                return False
-        # All predecessors are visited
-        return True
-
-    def traverseNode(self, cur_node, visited_nodes):
-        # Skip visited nodes
-        if cur_node.layer_name in visited_nodes:
-            return
-
-        if self.predVisited(cur_node, visited_nodes):
-            print(cur_node.layer_type)
-            print(cur_node.layer_name)
-            visited_nodes[cur_node.layer_name] = True
-
-            # Invoking traversal on outbound nodes
-            for output_node in cur_node.outputs:
-                self.traverseNode(output_node, visited_nodes)
-
-            # NOTE: Assuming that no outbound edges implies the last node in
-            # the graph
-            if len(cur_node.outputs) == 0:
-                self.last_node = cur_node
-
-    # Build and  Print the DFG in reverse postorder
+        self.nodes = list()
 
     def buildDFG(self):
         print("\n\n ****** Traversing and Printing DFG ******* \n\n")
-        visited_nodes = {}
-        # Starting traversal at the DFG root node
-        self.traverseNode(self.root_node, visited_nodes)
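+        # ONNX requires graph.node to be topologically sorted, so a plain
+        # in-order walk replaces the old reverse-postorder traversal.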
+        for node in self.graph.node:
+            self.nodes.append(self.emit_node(node))
 
     # This should be the place where partial evaluation happens
-    def emitNode(self, layer):
-        if layer.op_type == "Conv":
-            return Conv2DNode()
-        elif layer.op_type == "Tanh":
-            pass
-        elif layer.op_type == "MaxPool":
-            pass
-        elif layer.op_type == "Flatten":
-            pass
-        elif layer.op_type == "MatMul":
-            pass
-        elif layer.op_type == "Add":
-            pass
-        elif layer.op_type == "SoftMax":
-            pass
-        elif layer.op_type == "Identity":
-            pass
+    def emit_node(self, onnx_node):
+        if onnx_node.op_type == "Conv":
+            return Conv2DNode(onnx_node)
+        elif onnx_node.op_type == "MaxPool":
+            return MaxPool2DNode(onnx_node)
+        elif onnx_node.op_type == "AveragePool":
+            return AveragePool2DNode(onnx_node)
+        elif onnx_node.op_type == "MatMul":
+            return MatMulNode(onnx_node)
+        elif onnx_node.op_type == "Add":
+            return AddNode(onnx_node)
+        elif onnx_node.op_type == "Softmax":
+            return SoftMaxNode(onnx_node)
+        elif onnx_node.op_type == "Relu":
+            return ReluNode(onnx_node)
+        elif onnx_node.op_type == "BatchNormalization":
+            return BatchNormalizationNode(onnx_node)
+        elif onnx_node.op_type == "Pad":
+            return PadNode(onnx_node)
+        elif onnx_node.op_type == "Identity":
+            return IdentityNode(onnx_node)
+        elif onnx_node.op_type == "Flatten":
+            return FlattenNode(onnx_node)
         else:
             raise ValueError("Unsupported operator type!")
             sys.exit("Unsupported operator type!")
diff --git a/hpvm/projects/onnx/frontend/graph_codegen.py b/hpvm/projects/onnx/frontend/graph_codegen.py
index 5524c276a86f10a125a9f502ce3a66772c619bb2..9489fc19808a53ea269dfddeb17339e15a749acb 100644
--- a/hpvm/projects/onnx/frontend/graph_codegen.py
+++ b/hpvm/projects/onnx/frontend/graph_codegen.py
@@ -6,18 +6,18 @@ from graph_builder import *
 #from graph_ir import *
 from tensor import *
 
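+# Ops that emit no runtime call: Identity and Flatten are logical no-ops on
+# the flat tensor buffer, and Pad is assumed to be folded into the consuming
+# convolution's padding attribute.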
+skip_layer = ["Identity", "Flatten", "Pad"]
 
 class GraphCodeGen(object):
     def __init__(self, DFG, weights_dir, test_data=None, test_labels=None):
         self.program_str = ""
         self.graph = DFG.graph
         self.tensors = DFG.tensors
-        self.nodes = self.graph.node
+        self.nodes = DFG.nodes
         self.var_cnt = 0
         self.weights_dir = weights_dir
         self.test_data = test_data
-        self.test_labels =test_labels
-        self.skip_layer = ["Identity", "Flatten", "Pad"]
+        self.test_labels = test_labels
 
     ################################################
     # Aux functions
@@ -131,7 +131,7 @@ class GraphCodeGen(object):
           inst_str += self.tensors[cur_node.input[4]].get_mapped_name() + ", "
           inst_str += str(epsilon)
           inst_str += "); \n"
-        elif cur_node.op_type in self.skip_layer:
+        elif cur_node.op_type in skip_layer:
             pass
         else:
             raise ValueError("Not supported op type:" + cur_node.op_type + "! \n")
@@ -168,6 +168,22 @@ class GraphCodeGen(object):
 
     def emit_graph(self):
         for node in self.nodes:
+            cur_node = node.node
+            # check that every input of this node has been mapped already
+            for i in cur_node.input:
+                self.tensors[i].get_mapped_name()
+            # assign a variable name to the node's single output
+            if len(cur_node.output) > 1:
+                raise ValueError("Expected exactly one output per layer, got "
+                                 + str(len(cur_node.output)))
+            if cur_node.op_type in skip_layer:
+                mapped_output_name = self.get_last_var()
+            else:
+                mapped_output_name = self.get_new_var()
+            output_name = cur_node.output[0]
+            self.tensors[output_name].set_mapped_name(mapped_output_name)
+            self.program_str += node.codegen(self.tensors)
+
+    def emit_graph2(self):
+        # legacy path: emit straight from the ONNX graph nodes
+        for node in self.graph.node:
-            #pass
             self.program_str += self.emit_node_call(node)
 
@@ -201,10 +217,10 @@ class GraphCodeGen(object):
 
     def emit_batch_loop(self, x_test=None):
         # FIXME: Dimensions from test data
-        N = x_test.shape[0]
-        C = x_test.shape[1]
-        H = x_test.shape[2]
-        W = x_test.shape[3]
+        N = 1  # x_test.shape[0]
+        C = 1  # x_test.shape[1]
+        H = 1  # x_test.shape[2]
+        W = 1  # x_test.shape[3]
 
         loop_str = ""
         loop_str += "\nstartMemTracking(); \n\n"
@@ -258,7 +274,7 @@ class GraphCodeGen(object):
         self.emit_header()
         self.emit_weights()
         self.emit_batch_loop()
-        self.emit_graph()
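+        # emit_graph2 keeps the old per-ONNX-node path; switch back to
+        # emit_graph() once the graph_ir codegen path is complete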
+        self.emit_graph2()
         self.emit_batch_loop_end()
         self.emit_footer()
         # Write the program to source/disk
diff --git a/hpvm/projects/onnx/frontend/graph_ir.py b/hpvm/projects/onnx/frontend/graph_ir.py
index 71b577ff7fccbcea71fd89b4aeb2d26d407ddb6f..52c7569a354a7b69389a958083ee5e30ec47596c 100644
--- a/hpvm/projects/onnx/frontend/graph_ir.py
+++ b/hpvm/projects/onnx/frontend/graph_ir.py
@@ -1,18 +1,16 @@
 ################################################
 # Top Level DFGNode interface
 ################################################
-class DFGNode(object):
-    def add_output(self, output_node):
-        self.outputs.append(output_node)
 
-    def add_input(self, input_node):
-        self.inputs.append(input_node)
 
-    def __init__(self, layer):
-        self.inputs = list()
-        self.outputs = list()
-        self.name = layer.name
-        self.op_type = layer.op_type
+class DFGNode(object):
+    def __init__(self, onnx_node):
+        self.node = onnx_node
+        self.name = onnx_node.name
+        self.op_type = onnx_node.op_type
+
+    def codegen(self, tensors):
+        # default: emit no code; subclasses return a string of tensor-runtime
+        # calls, resolving ONNX value names through the `tensors` map
+        return ""
 
 
 '''
@@ -21,6 +19,8 @@ e.g. HardSigmoid, LeakyRelu, PRelu, Pow, Reciprocal,
 Relu, Selu, Sigmoid, Softplus, Sqrt, ThresholdedRelu,
 Abs, Ceil, Elu, Floor, Neg
 '''
+
+
 class ActivationNode(DFGNode):
     pass
 
@@ -31,6 +31,8 @@ In other words, they are logical comparison operators
 e.g. And, Equal, Greater, GreaterOrEqual, Less, LessOrEqual,
 Or, Xor
 '''
+
+
 class LogicalOpNode(DFGNode):
     pass
 
@@ -43,110 +45,200 @@ class AddNode(DFGNode):
     def __init__(self, layer):
         DFGNode.__init__(self, layer)
 
-    def codegen(self):
-        pass
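+    # Emits one runtime call, e.g. (illustrative variable names):
+    #   void* var_2 = tensorAdd(var_1, conv2d_1_b);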
+    def codegen(self, tensors):
+        cur_node = self.node
+        left_input = tensors[cur_node.input[0]].get_mapped_name()
+        right_input = tensors[cur_node.input[1]].get_mapped_name()
+        mapped_output_name = tensors[cur_node.output[0]].get_mapped_name()
+        inst_str = "void* " + mapped_output_name + " = "
+        inst_str += "tensorAdd(" + left_input + ", " + right_input + "); \n"
+        return inst_str
 
 
 class MatMulNode(DFGNode):
     def __init__(self, layer):
         DFGNode.__init__(self, layer)
 
-    def codegen(self):
-        pass
+    def codegen(self, tensors):
+        cur_node = self.node
+        left_input = tensors[cur_node.input[0]].get_mapped_name()
+        right_input = tensors[cur_node.input[1]].get_mapped_name()
+        mapped_output_name = tensors[cur_node.output[0]].get_mapped_name()
+        inst_str = "void* " + mapped_output_name + " = "
+        inst_str += "tensorGemmGPU(" + left_input + \
+            ", " + right_input + "); \n"
+        return inst_str
 
 
 class SoftMaxNode(DFGNode):
     def __init__(self, layer):
         DFGNode.__init__(self, layer)
 
-    def codegen(self):
-        pass
+    def codegen(self, tensors):
+        cur_node = self.node
+        mapped_input_name = tensors[cur_node.input[0]].get_mapped_name()
+        mapped_output_name = tensors[cur_node.output[0]].get_mapped_name()
+        inst_str = "void* " + mapped_output_name + " = "
+        inst_str += "tensorSoftmax(" + mapped_input_name + "); \n"
+        return inst_str
 
 
 class Conv2DNode(DFGNode):
     def __init__(self, layer):
         DFGNode.__init__(self, layer)
-        self.weights = layer.get_weights()[0]
-        print("\t", self.weights.shape)
-        self.use_bias = layer.use_bias
-        self.padding = layer.padding
-        self.strides = layer.strides
-        print("\t", self.strides)
-        print("\tPadding = ", self.padding)
-
-    def codegen(self):
-        pass
 
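+    # Emits tensorConvolution(...) plus an optional tensorAdd when a bias
+    # input is present, e.g. (illustrative names):
+    #   void* var_1 = tensorConvolution(input, conv2d_1_w, 2, 2, 1, 1, 1, 1);
+    #   void* var_1_bias = tensorAdd(var_1, conv2d_1_b);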
+    def codegen(self, tensors):
+        cur_node = self.node
+        input_var_name = tensors[cur_node.input[0]].get_mapped_name()
+        weight_name = tensors[cur_node.input[1]].get_mapped_name()
+        mapped_output_name = tensors[cur_node.output[0]].get_mapped_name()
+        strides = list()
+        padding = 0
+        for attr in cur_node.attribute:
+            if attr.name == "pads":
+                # FIXME: assumes symmetric padding; only pads[0] is used
+                padding = attr.ints[0]
+            elif attr.name == "strides":
+                for stride in attr.ints:
+                    strides.append(stride)
+
+        inst_str = "void* " + mapped_output_name + " = "
+        inst_str += "tensorConvolution(" + input_var_name + ", "
+        inst_str += weight_name + ", "
+        inst_str += str(padding) + ", "
+        inst_str += str(padding) + ", "
+        inst_str += str(strides[0]) + ", "
+        inst_str += str(strides[1]) + ", "
+        inst_str += "1, 1); \n"
+
+        # optional bias add (third input); in ONNX only Conv carries it,
+        # while in Keras a bias add can also appear in Dense
+        if len(cur_node.input) == 3:
+            # variable allocation lives in GraphCodeGen; derive a fresh
+            # name locally instead of calling its get_new_var()
+            mapped_output_name2 = mapped_output_name + "_bias"
+            inst_str += "void* " + mapped_output_name2 + " = "
+            inst_str += "tensorAdd(" + mapped_output_name + ", "
+            inst_str += tensors[cur_node.input[2]].get_mapped_name()
+            inst_str += "); \n"
+            tensors[cur_node.output[0]].set_mapped_name(mapped_output_name2)
+        return inst_str
 
-class DepthwiseConv2DNode(DFGNode):
+
+class MaxPool2DNode(DFGNode):
     def __init__(self, layer):
         DFGNode.__init__(self, layer)
-        self.weights = layer.get_weights()[0]
-        print("\t", self.weights.shape)
-        self.use_bias = layer.use_bias
-        self.padding = layer.padding
-        self.strides = layer.strides
-        print("\t", self.strides)
-        print("\tPadding = ", self.padding)
-
-    def codegen(self):
-        pass
 
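+    # Emits tensorPooling(input, pool_type, pool_h, pool_w, v_pad, h_pad,
+    # v_stride, h_stride); pool_type 0 selects max pooling, 1 average
+    # (matching the encoding used by emit_node_call in graph_codegen.py).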
+    def codegen(self, tensors):
+        cur_node = self.node
+        input_var_name = tensors[cur_node.input[0]].get_mapped_name()
+        mapped_output_name = tensors[cur_node.output[0]].get_mapped_name()
+        strides = list()
+        pool_size = list()
+        for attr in cur_node.attribute:
+            if attr.name == "kernel_shape":
+                for pool in attr.ints:
+                    pool_size.append(pool)
+            elif attr.name == "strides":
+                for stride in attr.ints:
+                    strides.append(stride)
+        # FIXME: non-zero padding is *NOT* currently supported
+        padding = 0
+        pool_type = "0"  # max pooling
+        inst_str = "void* " + mapped_output_name + " = "
+        inst_str += "tensorPooling(" + input_var_name + "," + \
+            pool_type + "," + str(pool_size[0]) + "," + str(pool_size[1])
+        inst_str += "," + str(padding) + "," + str(padding) + \
+            "," + str(strides[0]) + "," + str(strides[1])
+        inst_str += "); \n"
+        return inst_str
+
+
+class AveragePool2DNode(DFGNode):
+    def __init__(self, layer):
+        DFGNode.__init__(self, layer)
 
-class DenseNode(DFGNode):
+    def codegen(self, tensors):
+        cur_node = self.node
+        input_var_name = tensors[cur_node.input[0]].get_mapped_name()
+        mapped_output_name = tensors[cur_node.output[0]].get_mapped_name()
+        strides = list()
+        pool_size = list()
+        for attr in cur_node.attribute:
+            if attr.name == "kernel_shape":
+                for pool in attr.ints:
+                    pool_size.append(pool)
+            elif attr.name == "strides":
+                for stride in attr.ints:
+                    strides.append(stride)
+        # FIXME: non-zero padding is *NOT* currently supported
+        padding = 0
+        pool_type = "1"  # average pooling
+        inst_str = "void* " + mapped_output_name + " = "
+        inst_str += "tensorPooling(" + input_var_name + "," + \
+            pool_type + "," + str(pool_size[0]) + "," + str(pool_size[1])
+        inst_str += "," + str(padding) + "," + str(padding) + \
+            "," + str(strides[0]) + "," + str(strides[1])
+        inst_str += "); \n"
+        return inst_str
+
+
+class ReluNode(DFGNode):
     def __init__(self, layer):
         DFGNode.__init__(self, layer)
-        self.weights = layer.get_weights()[0]
-        print("\t", self.weights.shape)
-        self.use_bias = layer.use_bias
 
-    def codegen(self):
+    def codegen(self, tensors):
+        cur_node = self.node
+        mapped_input_name = tensors[cur_node.input[0]].get_mapped_name()
+        mapped_output_name = tensors[cur_node.output[0]].get_mapped_name()
+        inst_str = "void* " + mapped_output_name + " = "
+        inst_str += "tensorRelu(" + mapped_input_name + "); \n"
+        return inst_str
+
+
+class BatchNormalizationNode(DFGNode):
+    def __init__(self, layer):
+        DFGNode.__init__(self, layer)
+
+    def codegen(self, tensors):
+        cur_node = self.node
+        mapped_input_name = tensors[cur_node.input[0]].get_mapped_name()
+        mapped_output_name = tensors[cur_node.output[0]].get_mapped_name()
+        epsilon = ""
+        for attr in cur_node.attribute:
+            if attr.name == "epsilon":
+                epsilon = str(attr.f)
+        # inputs 1..4 are scale (gamma), bias (beta), mean, variance
+        inst_str = "void* " + mapped_output_name + " = "
+        inst_str += "tensorBatchNorm(" + mapped_input_name + ", "
+        inst_str += tensors[cur_node.input[1]].get_mapped_name() + ", "
+        inst_str += tensors[cur_node.input[2]].get_mapped_name() + ", "
+        inst_str += tensors[cur_node.input[3]].get_mapped_name() + ", "
+        inst_str += tensors[cur_node.input[4]].get_mapped_name() + ", "
+        inst_str += epsilon
+        inst_str += "); \n"
+        return inst_str
+
+
+class PadNode(DFGNode):
+    def __init__(self, layer):
+        DFGNode.__init__(self, layer)
+
+    def codegen(self, tensors):
-        pass
+        return ""  # no code emitted; Pad is folded into the consumer
 
 
-class MaxPool2DNode(DFGNode):
+class IdentityNode(DFGNode):
     def __init__(self, layer):
         DFGNode.__init__(self, layer)
-        self.pool_size = layer.pool_size
-        self.strides = layer.strides
-        print("\t pool_size = ", self.pool_size)
-        print("\t strides = ", self.strides)
 
-    def codegen(self):
+    def codegen(self, tensors):
-        pass
+        return ""  # no code emitted; Identity is a no-op
 
 
-class AveragePooling2DNode(DFGNode):
+class FlattenNode(DFGNode):
     def __init__(self, layer):
         DFGNode.__init__(self, layer)
-        self.pool_size = layer.pool_size
-        self.strides = layer.strides
-        print("\t pool_size = ", self.pool_size)
-        print("\t strides = ", self.strides)
 
-    def codegen(self):
+    def codegen(self, tensors):
-        pass
+        return ""  # no code emitted; Flatten is a logical reshape
 
 
 class ZeroPadding2DNode(DFGNode):
     def __init__(self, layer):
         DFGNode.__init__(self, layer)
-        print("***ZeroPaddding \n")
-        self.padding = layer.padding
-        print("padding = ", self.padding)
 
-    def codegen(self):
+    def codegen(self, tensors):
         pass
 
 
-class BatchNormalizationNode(DFGNode):
+class DepthwiseConv2DNode(DFGNode):
+    def __init__(self, layer):
+        DFGNode.__init__(self, layer)
+
+    def codegen(self, tensors):
+        pass
+
+
+class DenseNode(DFGNode):
     def __init__(self, layer):
         DFGNode.__init__(self, layer)
-        self.epsilon = layer.epsilon
-        self.beta = layer.beta
-        self.gamma = layer.gamma
-        self.moving_mean = layer.moving_mean
-        self.moving_variance = layer.moving_variance
 
-    def codegen(self):
+    def codegen(self, tensors):
         pass
diff --git a/hpvm/projects/onnx/frontend/main.py b/hpvm/projects/onnx/frontend/main.py
index 5a76e80de1067f019747a705dbacdf96291c7fb0..3a9344015b7082fd8e59f92fc6e4faf0e80a0b04 100644
--- a/hpvm/projects/onnx/frontend/main.py
+++ b/hpvm/projects/onnx/frontend/main.py
@@ -5,6 +5,9 @@ import onnx
 import glob
 #from onnxruntime.backend.backend import OnnxRuntimeBackend as backend
 
+onnx_file_path = "../models/keras/lenet.onnx"
+src_emit_dir = "./test_src"
+
 def check_version(model, new_version):
     try:
         opset = model.opset_import[0].version if model.opset_import else 1
@@ -29,7 +32,9 @@ def check_version(model, new_version):
     return model
 
 def compile(model):
-    weights_dir = './test_src'
+    # TODO: expose this as a command-line option, defaulting to src_emit_dir
+    weights_dir = src_emit_dir
     opset_version_default = 11
     # test_data_dir = '../models/mnist/test_data_set_0'
     # converted_model = convert_version(model)
@@ -41,7 +46,8 @@ def compile(model):
     gCodegen.compile()
 
 def main():
-    model = onnx.load('../models/keras/lenet.onnx')
+    # TODO: take the model path from command-line arguments
+    model = onnx.load(onnx_file_path)
     # model = onnx.load('../models/keras/vgg16_cifar10.onnx')
     compile(model)