diff --git a/hpvm/projects/onnx/frontend/approx_codegen.py b/hpvm/projects/onnx/frontend/approx_codegen.py
index cafee7a2536b1a4ff1ba2d98b02c930a9189eb22..5fd6f9490a93bbe00558f246ebfbfcc6dbf46cad 100644
--- a/hpvm/projects/onnx/frontend/approx_codegen.py
+++ b/hpvm/projects/onnx/frontend/approx_codegen.py
@@ -101,7 +101,7 @@ class GraphCodeGen:
         f.write(source)
         f.close()
 
-    def codegen(self, model, weights_dir, test_data):
+    def compile(self, model, weights_dir, test_data):
         self.emitHeaders()
         self.emitRoot()
         self.emitMainFunc(test_data)
diff --git a/hpvm/projects/onnx/frontend/common.py b/hpvm/projects/onnx/frontend/common.py
index 9fa037ef13d783ca95289827c104b8c58f63d091..26e1aa715e4e504820f7bfd18e0bbaa4e2ce3767 100644
--- a/hpvm/projects/onnx/frontend/common.py
+++ b/hpvm/projects/onnx/frontend/common.py
@@ -31,6 +31,29 @@ class InputTensor(Tensor):
 class WeightTensor(Tensor):
 	def __init__(self, weight_proto):
 		Tensor.__init__(self, weight_proto)
+		self.shape = list()
 		self.input_data = numpy_helper.to_array(weight_proto)#.reshape(tuple(input_proto.dims))
-
+		print(self.input_data.shape)
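+		# Normalize every weight to a 4-D NCHW shape: a 1-D bias of length K
+		# becomes (1, K, 1, 1) and a 2-D (M, N) matrix becomes (1, 1, M, N).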
+		if len(self.input_data.shape) == 1:
+			self.shape = [1, self.input_data.shape[0], 1, 1]
+		elif len(self.input_data.shape) == 2:
+			self.shape = [1, 1, self.input_data.shape[0], self.input_data.shape[1]]
+		elif len(self.input_data.shape) == 4:
+			self.shape = list(self.input_data.shape)
+		else:
+			print(weight_proto.name)
+			self.shape = [1, 1, 1, 1]
+			# raise ValueError("Weight must be 1-, 2-, or 4-dimensional")
 
diff --git a/hpvm/projects/onnx/frontend/graph_builder.py b/hpvm/projects/onnx/frontend/graph_builder.py
index 3ded2f68e1ab6550ddf365d40f3e06eb8f616077..be281631fa5dd73f6691b7cafdecbce36498f3bb 100644
--- a/hpvm/projects/onnx/frontend/graph_builder.py
+++ b/hpvm/projects/onnx/frontend/graph_builder.py
@@ -3,7 +3,7 @@ from onnx import numpy_helper
 from graph_ir import Node
 from common import InputTensor, WeightTensor
 
-support_onnx_ops = {"DepthwiseConv2D" : None,
+support_onnx_ops = {"DepthwiseConv" : [2],
                "Conv" : [2], # only 2d supported here
                "MatMul" : None,
                "MaxPool": [2], # only 2d supported here
@@ -19,12 +19,12 @@ support_onnx_ops = {"DepthwiseConv2D" : None,
                "Tanh": None}
 
 class GraphBuilder(object):
-    def __init__(self, model, shape, dtype, opset, weight_dir):
+    def __init__(self, model, shape, dtype, weight_dir):
         self._check_model(model)
+        self._check_ops(model)
         self.model = model
         self.dtype = dtype
         self.graph = model.graph
-        self.opset = opset
         self.weight_dir = weight_dir
         self.shape = shape if shape else self._build_shape()
         self.tensors = dict()
@@ -48,6 +48,19 @@ class GraphBuilder(object):
             raise ImportError(
                 "Unable to import onnx.checker which is required {}".format(e))
 
+    def _check_ops(self, model):
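+        # Count every unsupported op first so the error reports all of them at once.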
+        unsupported = dict()
+        for node in model.graph.node:
+            if node.op_type not in support_onnx_ops:
+                unsupported[node.op_type] = unsupported.get(node.op_type, 0) + 1
+        if len(unsupported) != 0:
+            print(sorted(unsupported.items(), key=lambda x: x[1], reverse=True))
+            raise ValueError(
+                "The operator(s) listed above are not supported! Compilation aborted.")
+
     def _build_shape(self):
         shape = {}
         for input in self.graph.input:
@@ -85,6 +98,7 @@ class GraphBuilder(object):
 
     def _support_check(self, node):
         op_name = node.op_type
         if op_name not in support_onnx_ops:
             return False
         else:
@@ -95,6 +109,7 @@ class GraphBuilder(object):
                 for attr in node.attribute:
                     # partially evaluate the kernel shape
                     if attr.name == 'kernel_shape':
+                        # TODO: do not assume the kernel shape is always stored in ints
                         return len(attr.ints) in support_onnx_ops[op_name]
                 return False
                
@@ -109,12 +124,17 @@ class GraphBuilder(object):
 
     def build_graph(self):
         # parse weight
+        weight_cnt = 0
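+        # Map each initializer to a stable C identifier: weight_0, weight_1, ...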
         for weight_tensor in self.graph.initializer:
             self.tensors[weight_tensor.name] = WeightTensor(weight_tensor)
+            self.tensors[weight_tensor.name].set_mapped_name("weight_" + str(weight_cnt))
+            weight_cnt += 1
         # parse input
         for i in self.graph.input:
             if i.name not in self.tensors:
                 self.tensors[i.name] = InputTensor(i.name)
+                # FIXME: This input name is hardcoded
+                self.tensors[i.name].set_mapped_name("input")
         # parse intermediate tensor
         for node in self.graph.node:
             op_name = node.op_type
diff --git a/hpvm/projects/onnx/frontend/graph_codegen.py b/hpvm/projects/onnx/frontend/graph_codegen.py
index 7232cbd22835c54cd7b1429a1908680e66397cf3..36d96f0d19456f2147d3bc1e5d806c959abd1ecf 100644
--- a/hpvm/projects/onnx/frontend/graph_codegen.py
+++ b/hpvm/projects/onnx/frontend/graph_codegen.py
@@ -8,49 +8,183 @@ from common import *
 
 
 class GraphCodeGen(object):
-    def __init__(self, DFG, weights_dir, test_data, test_labels):
+    def __init__(self, DFG, weights_dir, test_data=None, test_labels=None):
         self.program_str = ""
         self.graph = DFG.graph
         self.tensors = DFG.tensors
+        self.nodes = self.graph.node
         self.var_cnt = 0
         self.weights_dir = weights_dir
         self.test_data = test_data
         self.test_labels =test_labels
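+        # Pass-through ops: no code is emitted and their output reuses the last var.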
+        self.skip_layer = ["Identity", "Flatten", "Pad"]
 
     ################################################
     # Aux functions
     ################################################
-    def get_var_cnt(self):
-        cnt = self.var_cnt
+    def get_last_var(self):
+        return "var_" + str(self.var_cnt)
+
+    def get_new_var(self):
         self.var_cnt = self.var_cnt + 1
-        return cnt
+        return "var_" + str(self.var_cnt)
+
+    def emit_node_call(self, cur_node):
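+        # Translate a single ONNX node into a tensor-runtime call and return
+        # the emitted C++ statement(s).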
+        inst_str = ""
+        # TODO: check that every input of this node has a mapped name, e.g.:
+        # for i in cur_node.input:
+        #     self.tensors[i].get_mapped_name()
+        # set var name for output node
+        if len(cur_node.output) > 1:
+            raise ValueError("Nodes with more than one output are not supported!")
+        if cur_node.op_type in self.skip_layer:
+            mapped_output_name = self.get_last_var()
+        else:
+            mapped_output_name = self.get_new_var()
+        output_name = cur_node.output[0]
+        self.tensors[output_name].set_mapped_name(mapped_output_name)
+        if cur_node.op_type == "Conv":# or cur_node.op_type == "DepthwiseConv":
+          input_var_name = self.tensors[cur_node.input[0]].get_mapped_name()
+          weight = cur_node.input[1]
+          strides = list()
+          padding = 0
+          for attr in cur_node.attribute:
+            if attr.name == "pads":
+                padding = attr.ints[0]
+            elif attr.name == "strides":
+                for stride in attr.ints:
+                    strides.append(stride)
+
+          inst_str += "void* " + mapped_output_name + " = "
+          inst_str += "tensorConvolution(" + input_var_name + ", "
+          inst_str += self.tensors[cur_node.input[1]].get_mapped_name() + ", "
+          inst_str += str(padding) + ", "
+          inst_str += str(padding) + ", "
+          inst_str += str(strides[0]) + ", "
+          inst_str += str(strides[1]) + ", "
+          inst_str += "1, 1); \n"
+
+            # A bias add is optional: in ONNX it appears only as Conv's third
+            # input, while in Keras a bias add can also appear in Dense.
+            if len(cur_node.input) == 3:
+                mapped_output_name2 = self.get_new_var()
+                inst_str += "void* " + mapped_output_name2 + " = "
+                inst_str += "tensorAdd(" + mapped_output_name + ", "
+                inst_str += self.tensors[cur_node.input[2]].get_mapped_name()
+                inst_str += "); \n"
+                self.tensors[output_name].set_mapped_name(mapped_output_name2)
+        elif cur_node.op_type == "MaxPool" or cur_node.op_type == "AveragePool":  
+          input_var_name = self.tensors[cur_node.input[0]].get_mapped_name()
+          strides = list()
+          pool_size = list()
+          for attr in cur_node.attribute:
+            if attr.name == "kernel_shape":
+                for pool in attr.ints:
+                    pool_size.append(pool)
+            elif attr.name == "strides":
+                for stride in attr.ints:
+                    strides.append(stride)
+          # FIXME: Non-same padding is *NOT* currently supported
+          padding = 0
+          pool_type = 0
+          if cur_node.op_type == "MaxPool":
+            pool_type = "0"   
+          elif cur_node.op_type == "AveragePool":
+            pool_type = "1"
+          # tensorPooling(input, pool_type, pool_h, pool_w, v_pad, h_pad, v_stride, h_stride)
+          inst_str += "void* " + mapped_output_name + " = "
+          inst_str += "tensorPooling(" + input_var_name + "," + pool_type + "," + str(pool_size[0]) + "," + str(pool_size[1]) 
+          inst_str +=  "," + str(padding) + "," + str(padding) + "," + str(strides[0]) + "," + str(strides[1])
+          inst_str += "); \n"
+          # self.program_str += inst_str
+        elif cur_node.op_type == "MatMul":
+            left_input = self.tensors[cur_node.input[0]].get_mapped_name()
+            right_input = self.tensors[cur_node.input[1]].get_mapped_name()
+            inst_str += "void* " + mapped_output_name + " = "
+            inst_str += "tensorGemmGPU(" + left_input + ", " + right_input + "); \n"
+        elif cur_node.op_type == "Add":
+            left_input = self.tensors[cur_node.input[0]].get_mapped_name()
+            right_input = self.tensors[cur_node.input[1]].get_mapped_name()
+            inst_str += "void* " + mapped_output_name + " = "
+            inst_str += "tensorAdd(" + left_input + ", " + right_input + "); \n"
+        elif cur_node.op_type == "Softmax":
+            mapped_input_name = self.tensors[cur_node.input[0]].get_mapped_name()
+            inst_str += "void* " + mapped_output_name + " = "
+            inst_str += "tensorSoftmax(" + mapped_input_name + "); \n"
+        elif cur_node.op_type == "Relu":
+            mapped_input_name = self.tensors[cur_node.input[0]].get_mapped_name()
+            inst_str += "void* " + mapped_output_name + " = "
+            inst_str += "tensorRelu(" + mapped_input_name + "); \n"
+        elif cur_node.op_type == "BatchNormalization":
+          mapped_input_name = self.tensors[cur_node.input[0]].get_mapped_name()
+          epsilon = ""
+          for attr in cur_node.attribute:
+            if attr.name == "epsilon":
+                epsilon = str(attr.f)
+          inst_str += "void* " + mapped_output_name + " = "
+          inst_str += "tensorBatchNorm(" + mapped_input_name + ", "
+          inst_str += self.tensors[cur_node.input[1]].get_mapped_name() + ", "
+          inst_str += self.tensors[cur_node.input[2]].get_mapped_name() + ", "
+          inst_str += self.tensors[cur_node.input[3]].get_mapped_name() + ", "
+          inst_str += self.tensors[cur_node.input[4]].get_mapped_name() + ", "
+          inst_str += str(epsilon)
+          inst_str += "); \n"
+        elif cur_node.op_type in self.skip_layer:
+            pass
+        else:
+            raise ValueError("Not supported op type:" + cur_node.op_type + "! \n")
+        return inst_str
+
+    def not_used(self):
+        # Dead reference snippet from the older Keras-style codegen: cur_node,
+        # inst_str and out_var_name1 are undefined here, so this must not be called.
+        if cur_node.op_type == "BatchNormalization":
+          input_var_name = self.getSingleInputName(cur_node)
+
+          inst_str += "void* " + out_var_name1 + " = "
+          inst_str += "tensorBatchNorm(" + input_var_name + ", "
+          inst_str += cur_node.layer_name + "_gamma, "
+          inst_str += cur_node.layer_name + "_beta, "
+          inst_str += cur_node.layer_name + "_mean, "
+          inst_str += cur_node.layer_name + "_variance, "
+          inst_str += str(cur_node.epsilon)
+          inst_str += "); \n"
+          
+          # self.program_str += inst_str
 
     ################################################
     # Emit functions for code generation
     ################################################
 
     def emit_weights(self):
-        weight_str += "std::string dir_prefix = std::string(\"" + str(self.weights_dir) + "\");"
-        weight_str += "std::string input_path =  dir_prefix + std::string(\"input.bin\");"
-        weight_str += "std::string labels_path =  dir_prefix + std::string(\"labels.bin\");"
+        weights_str = ""
+        weights_str += "std::string dir_prefix = std::string(\"" + str(self.weights_dir) + "\");\n"
+        weights_str += "std::string input_path =  dir_prefix + std::string(\"input.bin\");\n"
+        weights_str += "std::string labels_path =  dir_prefix + std::string(\"labels.bin\");\n"
         for tensor in self.tensors.values():
             if isinstance(tensor, WeightTensor):
-                print(tensor.name)
-
-    def traverse_graph(self, cur_node, visited):
-        if cur_node in visited:
-            return
+                weights_str += self.emit_single_weight(tensor)
+        self.program_str += weights_str
 
-        if dfg.predVisited(cur_node, visited):
-            visited_nodes[cur_node.layer_name] = True
-            self.program_str += cur_node.codegen()
-            for output_node in cur_node.outputs:
-                self.traverse_graph(dfg, output_node, visited)
+    def emit_single_weight(self, tensor):
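+        # e.g. for a weight mapped to "weight_0" with shape (N, C, H, W), this emits:
+        #   std::string weight_0_path = dir_prefix + std::string("weight_0_path.bin");
+        #   void* weight_0 = readTrainedWeights(weight_0_path.c_str(), 0, N, C, H, W);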
+        N, C, H, W = tensor.shape
+        mapped_name = tensor.get_mapped_name()
+        file_path = mapped_name + "_path" 
+        unique_file_name = file_path + ".bin"
+        weight_str = "std::string " + file_path + " = " + " dir_prefix + std::string(\""
+        weight_str += unique_file_name + "\"); \n"
+        weight_str += "void* " + mapped_name + " = " + " readTrainedWeights("
+        weight_str += file_path + ".c_str(), 0," + str(N) + "," + str(C) + "," + str(H) + "," + str(W)
+        weight_str += "); \n"
+        return weight_str
 
     def emit_graph(self):
-        self.build_graph()
-        visited_nodes = {}
-        self.traverse_graph(self.dfg.root, visited)
+        for node in self.nodes:
+            self.program_str += self.emit_node_call(node)
 
     def emit_header(self):
         headers = "\n#include <stdio.h> \n"
@@ -60,6 +194,7 @@ class GraphCodeGen(object):
         headers += "#include <sys/types.h> \n"
         headers += "#include <sys/stat.h> \n"
         headers += "#include <string.h> \n"
+        headers += "#include <chrono> \n"
         headers += "#include \"../../tensor_runtime/include/tensor_runtime.h\" \n"
         headers += "#include \"../include/utils.h\" \n\n"
 
@@ -69,21 +204,22 @@ class GraphCodeGen(object):
         self.program_str += main_func
         self.program_str += initialization
 
-    def emit_footer(self, test_data):
-        if test_data is not None and self.dfg.last_node is not None:
-            last_node = self.dfg.last_node
-            output_var = self.output_map[last_node.layer_name]
+    def emit_footer(self, test_data=None):
+        #if test_data is not None and self.dfg.last_node is not None:
+            #last_node = self.dfg.last_node
+            #output_var = self.output_map[last_node.layer_name]
 
         destructors = "\nllvm_hpvm_cleanupTensorRt(); \n"
         end_main = "\nreturn 0; \n\n}\n"
         self.program_str += destructors
         self.program_str += end_main
 
-    def emit_batch_loop(self, x_test):
-        N = x_test.shape[0]
-        C = x_test.shape[1]
-        H = x_test.shape[2]
-        W = x_test.shape[3]
+    def emit_batch_loop(self, x_test=None):
+        # FIXME: Dimensions from test data
+        N = 1  # x_test.shape[0]
+        C = 2  # x_test.shape[1]
+        H = 3  # x_test.shape[2]
+        W = 4  # x_test.shape[3]
 
         loop_str = ""
         loop_str += "\nstartMemTracking(); \n\n"
@@ -105,8 +241,8 @@ class GraphCodeGen(object):
     def emit_batch_loop_end(self):
         end_loop_str = ""
         end_loop_str += "\nuint32_t* labels = readLabelsBatch3(labels_path.c_str(),start,end); \n"
-        last_node = self.dfg.last_node
-        output_var = self.output_map[last_node.layer_name]
+        #last_node = self.dfg.last_node
+        output_var = self.tensors[self.graph.output[0].name].get_mapped_name()
         accuracy_call = "\nfloat accuracy = computeAccuracy3(labels, " + \
             output_var + "); \n"
         end_loop_str += accuracy_call
@@ -130,10 +266,10 @@ class GraphCodeGen(object):
     # program with HPVM intrinsics
     ################################################
 
-    def compile(self, src_dir):
-        if os.path.exists(self.weights_dir):
-            raise ValueError("Weight dir existed. Compilation interrupted!")
-        os.mkdir(weights_dir)
+    def compile(self):
+        #if os.path.exists(self.weights_dir):
+        #    raise ValueError("Weight dir existed. Compilation interrupted!")
+        #os.mkdir(self.weights_dir)
         self.emit_header()
         self.emit_weights()
         self.emit_batch_loop()
@@ -141,4 +277,4 @@ class GraphCodeGen(object):
         self.emit_batch_loop_end()
         self.emit_footer()
         # Write the program to source/disk
-        self.emit_source(src_dir)
+        self.emit_source(self.weights_dir)
diff --git a/hpvm/projects/onnx/frontend/main.py b/hpvm/projects/onnx/frontend/main.py
index 89da216b65eece7b554ed787382597574d567178..84e17747a473d6ea5e4553723d18efe6f5007c8f 100644
--- a/hpvm/projects/onnx/frontend/main.py
+++ b/hpvm/projects/onnx/frontend/main.py
@@ -3,44 +3,47 @@ import sys
 import numpy as np
 import onnx
 import glob
-
-from onnx import version_converter
 #from onnxruntime.backend.backend import OnnxRuntimeBackend as backend
 
-# onnx2hpvm modules
-from graph_builder import GraphBuilder
-from graph_codegen import GraphCodeGen
-# from approx_codegen import GraphCodeGen
-
-
-def convert_version(model, new_version):
-
-    print('The model before conversion:\n{}'.format(model))
-
-    # A full list of supported adapters can be found here:
-    # https://github.com/onnx/onnx/blob/master/onnx/version_converter.py#L21
-    # Apply the version conversion on the original model
-    converted_model = version_converter.convert_version(model, new_version)
-
-    print('The model after conversion:\n{}'.format(converted_model))
-    return converted_model
-
-
-def main():
-    model = onnx.load('../models/keras/lenet.onnx')
-    weights_dir = './test_src'
-    # test_data_dir = '../models/mnist/test_data_set_0'
-    # converted_model = convert_version(model)
+def check_version(model, new_version):
     try:
         opset = model.opset_import[0].version if model.opset_import else 1
     except AttributeError:
         opset = 1  # default opset version set to 1 if not specified
     print("opset version: ", opset)
-    gBuilder = GraphBuilder(model, None, "float32", opset, weights_dir)
-    gBuilder.build_graph()
-    #gCodegen = GraphCodeGen(gBuilder.build_graph())
-    # gCodegen.codegen(weights_dir, test_data)#, test_labels)
+    if opset != new_version:
+        #print('The model before conversion:\n{}'.format(model))
+
+        # A full list of supported adapters can be found here:
+        # https://github.com/onnx/onnx/blob/master/onnx/version_converter.py#L21
+        # Apply the version conversion on the original model
+        from onnx import version_converter
+        try:
+            converted_model = version_converter.convert_version(model, new_version)
+            return converted_model
+        except RuntimeError as e:
+            print("Current version {} of ONNX model not supported!".format(opset))
+            print("Coversion failed with message below:")
+            raise e
+        #print('The model after conversion:\n{}'.format(converted_model))
+    return model
+
+def compile(model):
+    weights_dir = './test_src'
+    opset_version_default = 11
+    # test_data_dir = '../models/mnist/test_data_set_0'
+    # model = check_version(model, opset_version_default)
+    from graph_builder import GraphBuilder
+    from graph_codegen import GraphCodeGen
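+    # Pipeline: build the DFG from the ONNX graph, then emit the C++ program
+    # against the HPVM tensor runtime.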
+    gBuilder = GraphBuilder(model, None, "float32", weights_dir)
+    gCodegen = GraphCodeGen(gBuilder.build_graph(), weights_dir)
+    gCodegen.compile()
 
+def main():
+    model = onnx.load('../models/keras/resnet.onnx')
+    # model = onnx.load('../models/keras/vgg16_cifar10.onnx')
+    compile(model)
 
 if __name__ == "__main__":
     main()