diff --git a/hpvm/projects/onnx/frontend/config.py b/hpvm/projects/onnx/frontend/config.py
index 55ae0cebf47e42ca24f35bb80fb60c1ec15cb140..8f42dafa4ee5591f095f410f1bfb390c6e12ace4 100644
--- a/hpvm/projects/onnx/frontend/config.py
+++ b/hpvm/projects/onnx/frontend/config.py
@@ -1,6 +1,7 @@
 model_name = "alexnet"
+compile_type = 1 # 0 for HPVM Tensor Runtime, 1 for HPVM C Interface
 input_size = [1,2,3,4]
-onnx_file_dir = "../models/keras/alexnet.onnx"
+onnx_file_dir = "../models/keras/lenet.onnx"
 opset_version_default = 10
 src_emit_dir = "./test_src"
 
diff --git a/hpvm/projects/onnx/frontend/graph_ir.py b/hpvm/projects/onnx/frontend/graph_ir.py
index 745e5d0429fc7d8c6c11c8efb426160f5a6ccabd..b14de79439b3a7f6c28cf8a88594af94dc773fe7 100644
--- a/hpvm/projects/onnx/frontend/graph_ir.py
+++ b/hpvm/projects/onnx/frontend/graph_ir.py
@@ -52,6 +52,10 @@ class AddNode(DFGNode):
         self.inst_str += "tensorAdd(" + left_input + ", " + right_input + "); \n"
         return self.inst_str
 
+    def hpvm_codegen(self, tensors):
+        return "  void *r = __visc__tensor_add(t1, t2); \n"
+
+
 
 class MatMulNode(DFGNode):
 
@@ -65,6 +69,9 @@ class MatMulNode(DFGNode):
             ", " + right_input + "); \n"
         return self.inst_str
 
+    def hpvm_codegen(self, tensors):
+        return "  void *r = __visc__tensor_mul(t1, t2); \n"
+
 
 class SoftMaxNode(DFGNode):
 
@@ -76,6 +83,9 @@ class SoftMaxNode(DFGNode):
         self.inst_str += "tensorSoftmax(" + mapped_input_name + "); \n"
         return self.inst_str
 
+    def hpvm_codegen(self, tensors):
+        return "  void* r = __visc__tensor_softmax(t1); \n"
+
 
 class Conv2DNode(DFGNode):
 
@@ -112,6 +122,24 @@ class Conv2DNode(DFGNode):
             self.inst_str += "); \n"
         return self.inst_str
 
+    def hpvm_codegen(self, tensors):
+        strides = list()
+        padding = 0
+        cur_node = self.onnx_node
+        for attr in cur_node.attribute:
+            if attr.name == "pads":
+                padding = attr.ints[0]
+            elif attr.name == "strides":
+                for stride in attr.ints:
+                    strides.append(stride)
+        inst_str = "  void *r = __visc__tensor_convolution(t1, t2, "
+        inst_str += str(padding) + ", "
+        inst_str += str(padding) + ", "
+        inst_str += str(strides[0]) + ", "
+        inst_str += str(strides[1]) 
+        inst_str += "); \n"
+        return inst_str
+
 
 class MaxPool2DNode(DFGNode):
 
@@ -140,6 +168,24 @@ class MaxPool2DNode(DFGNode):
         self.inst_str += "); \n"
         return self.inst_str
 
+    def hpvm_codegen(self, tensors):
+        cur_node = self.onnx_node
+        padding = 0
+        strides = list()
+        pool_size = list()
+        for attr in cur_node.attribute:
+            if attr.name == "kernel_shape":
+                for pool in attr.ints:
+                    pool_size.append(pool)
+            elif attr.name == "strides":
+                for stride in attr.ints:
+                    strides.append(stride)
+        inst_str = "  void* r = __visc__tensor_pool_max(t1, "
+        inst_str += str(pool_size[0]) + ", " + str(pool_size[1]) + ", "
+        inst_str += str(padding) + ", " + str(padding) + ", "
+        inst_str += str(strides[0]) + ", " + str(strides[1]) + "); \n"
+        return inst_str
+
 
 class AveragePool2DNode(DFGNode):
 
@@ -168,6 +214,24 @@ class AveragePool2DNode(DFGNode):
         self.inst_str += "); \n"
         return self.inst_str
 
+    def hpvm_codegen(self, tensors):
+        cur_node = self.onnx_node
+        padding = 0
+        strides = list()
+        pool_size = list()
+        for attr in cur_node.attribute:
+            if attr.name == "kernel_shape":
+                for pool in attr.ints:
+                    pool_size.append(pool)
+            elif attr.name == "strides":
+                for stride in attr.ints:
+                    strides.append(stride)
+        inst_str = "  void* r = __visc__tensor_pool_avg(t1, "
+        inst_str += str(pool_size[0]) + ", " + str(pool_size[1]) + ", "
+        inst_str += str(padding) + ", " + str(padding) + ", "
+        inst_str += str(strides[0]) + ", " + str(strides[1]) + "); \n"
+        return inst_str
+
 
 class ReluNode(DFGNode):
 
@@ -179,6 +243,9 @@ class ReluNode(DFGNode):
         self.inst_str += "tensorRelu(" + mapped_input_name + "); \n"
         return self.inst_str
 
+    def hpvm_codegen(self, tensors):
+        return "  void* r = __visc__tensor_relu(t1); \n"
+
 class TanhNode(DFGNode):
 
     def codegen(self, tensors):
@@ -189,6 +256,9 @@ class TanhNode(DFGNode):
         self.inst_str += "tensorTanh(" + mapped_input_name + "); \n"
         return self.inst_str
 
+    def hpvm_codegen(self, tensors):
+        return "  void* r = __visc__tensor_tanh(t1); \n"
+
 
 class BatchNormalizationNode(DFGNode):
 
@@ -210,6 +280,16 @@ class BatchNormalizationNode(DFGNode):
         self.inst_str += "); \n"
         return self.inst_str
 
+    def hpvm_codegen(self, tensors):
+        epsilon = ""
+        cur_node = self.onnx_node
+        for attr in cur_node.attribute:
+            if attr.name == "epsilon":
+                epsilon = str(attr.f)
+        inst_str = "  void *r = __visc__tensor_batchnorm(t1, t2, t3, t4, t5, "
+        inst_str += str(epsilon) + "); \n"
+        return inst_str
+
 
 class PadNode(DFGNode):
     pass
diff --git a/hpvm/projects/onnx/frontend/hpvm_codegen.py b/hpvm/projects/onnx/frontend/hpvm_codegen.py
index 2446e0747f77282f9eca8778a37860a176d56670..3d14d97b09fc8b45bc80a2d3ea7acd1baf58cc67 100644
--- a/hpvm/projects/onnx/frontend/hpvm_codegen.py
+++ b/hpvm/projects/onnx/frontend/hpvm_codegen.py
@@ -10,10 +10,14 @@ class HpvmCodeGen:
         self.var_cnt = 0
         self.weights_dir = weights_dir
         self.test_data_shape = test_data_shape
-        self.filter_names = dict()  # essentially weight tensors
+        # filter_names is essentially weight & 1st input tensor(s)
+        # TODO: Replace manually adding input to filter_names
+        self.filter_names = dict()  
+        self.filter_names["input"] = 1 
         for tensor in self.tensors.values():
             if isinstance(tensor, WeightTensor):
                 self.filter_names[tensor.get_mapped_name()] = 1
+
         print(self.filter_names)
 
     ################################################
@@ -41,6 +45,63 @@ class HpvmCodeGen:
         headers += "#include <tensorUtils.h> \n\n"
         self.program_str += headers
 
+    def emit_hpvm_node_structures(self):
+        def emit_hpvm_node_header(new_var, input_size):
+            node_header_str = "void " + new_var + "_node("
+            for i in range(input_size):
+              node_header_str += "void* t" + str(i + 1) + ", "
+              node_header_str += "size_t bytes_t" + str(i + 1)
+              if i < input_size - 1:
+                node_header_str += ", "
+                
+            node_header_str += ") { \n" 
+            node_header_str += "  __visc__hint(visc::CUDNN_TARGET); \n"
+            node_header_str += "  __visc__attributes(" + str(input_size) + ", "
+
+            for i in range(input_size):
+              node_header_str += "t" + str(i + 1) 
+              if i < input_size - 1:
+                node_header_str += ", "
+                  
+            node_header_str += ", 0); \n\n" 
+            return node_header_str
+
+        def emit_hpvm_node_footer(input_size):
+            node_footer_str = "  __visc__return("
+            node_footer_str += str(input_size) + ", "
+            node_footer_str += "r, "
+            node_footer_str += "(size_t) 0); \n"
+            node_footer_str += "}\n\n"
+            return node_footer_str
+
+        def emit_root_node_footer(self):
+            mapped_output_var = self.tensors[self.graph.output[0].name].get_mapped_name()
+            # Binding output of last DFG node to the Root Node output
+            root_footer_str = "\n  __visc__bindOut(" + \
+                mapped_output_var + ", 0, 0, 0); \n"
+            root_footer_str += "  __visc__bindOut(" + \
+                mapped_output_var + ", 1, 1, 0); \n"
+            root_footer_str += "\n}\n\n"
+            return root_footer_str
+        
+
+        node_str = ""
+        for node in self.nodes:
+            cur_node = node.onnx_node
+            if node.name in skip_layer:
+                mapped_output_name = self.get_last_var()
+            else:
+                mapped_output_name = self.get_new_var()
+            self.tensors[cur_node.output[0]].set_mapped_name(mapped_output_name)
+            if node.name in skip_layer:
+                continue
+            node_str += emit_hpvm_node_header(mapped_output_name, len(cur_node.input))
+            node_str += node.hpvm_codegen(self.tensors)
+            node_str += emit_hpvm_node_footer(2) # Hardcoded as in Keras frontend
+       
+        node_str += emit_root_node_footer(self)
+        self.program_str += node_str
+
     def emit_root_node_header(self):
         root_signature = "void root("
         index = 0
@@ -81,20 +142,14 @@ class HpvmCodeGen:
         self.program_str += root_struct
 
     def emit_hpvm_graph(self):
+        
+        def emit_hpvm_edge(node):
+            return ""
+
+        hpvm_graph_str = ""
         for node in self.nodes:
-            # check if all inputs of this node is mapped
-            cur_node = node.onnx_node
-            for i in cur_node.input:
-                self.tensors[i].get_mapped_name() 
-            # set var name for output node
-            if len(cur_node.output) > 1:
-                raise ValueError("Output number for a single layer larger than 1!")
-            if cur_node.op_type in skip_layer:
-                mapped_output_name = self.get_last_var()
-            else:
-                mapped_output_name = self.get_new_var()
-            self.tensors[cur_node.output[0]].set_mapped_name(mapped_output_name)
-            self.program_str += node.hpvm_codegen(self.tensors)
+            hpvm_graph_str += emit_hpvm_edge(node)
+        return hpvm_graph_str
 
     def emit_root_node_footer(self):
         mapped_output_var = self.tensors[self.graph.output[0].name].get_mapped_name()
@@ -106,9 +161,35 @@ class HpvmCodeGen:
         root_footer_str += "\n}\n\n"
         self.program_str += root_footer_str
 
-    def emit_main(self, test_data):
+    def emit_weights(self):
+        weights_str = "\n"
+        weights_str += "std::string dir_prefix = std::string(\"" + str(self.weights_dir) + "\");\n"
+        weights_str += "std::string input_path =  dir_prefix + std::string(\"input.bin\");\n"
+        weights_str += "std::string labels_path =  dir_prefix + std::string(\"labels.bin\");\n"
+        for tensor in self.tensors.values():
+            if isinstance(tensor, WeightTensor):
+                from graph_codegen import GraphCodeGen
+                weights_str += self.emit_single_weight(tensor)
+        return weights_str
+
+    def emit_single_weight(self, tensor):
+        N = tensor.shape[0]
+        C = tensor.shape[1]
+        H = tensor.shape[2]
+        W = tensor.shape[3]
+        mapped_name = tensor.get_mapped_name()
+        file_path = mapped_name + "_path" 
+        unique_file_name = file_path + ".bin"
+        weight_str = "std::string " + file_path + " = " + " dir_prefix + std::string(\""
+        weight_str += unique_file_name + "\"); \n"
+        weight_str += "void* " + mapped_name + " = " + " readTrainedWeights("
+        weight_str += file_path + ".c_str(), 0," + str(N) + "," + str(C) + "," + str(H) + "," + str(W)
+        weight_str += "); \n"
+        return weight_str
+
+    def emit_main(self):
         main_func_str = "int main(){ \n\n"
-        #main_func_str += self.weight_str
+        main_func_str += self.emit_weights()
         #main_func_str += self.input_str
         main_func_str += "\n__visc__init(); \n"
         main_func_str += "RootIn* args = static_cast<RootIn*>(malloc(sizeof(RootIn))); \n\n"
@@ -119,7 +200,7 @@ class HpvmCodeGen:
         main_func_str += "__visc__wait(dfg); \n\n"
         main_func_str += "void *result = static_cast<RootIn*>(args)->input; \n"
         main_func_str += "hpvm_request_tensor(result, 0); \n\n"
-        main_func_str += "__visc__cleanup(); \n "
+        main_func_str += "__visc__cleanup(); \n"
         main_func_str += "computeAccuracy3(labels, result); \n"
         main_func_str += "return 0; \n\n"
         main_func_str += "} \n"
@@ -133,10 +214,11 @@ class HpvmCodeGen:
 
     def compile(self):
         self.emit_header()
+        self.emit_hpvm_node_structures()
         self.emit_root_node_header()
         self.emit_root_structure()
         self.emit_hpvm_graph()
         self.emit_root_node_footer()
-        self.emit_main(self.test_data)
+        self.emit_main()
         # dump generated program string to source file
         self.emit_source(self.weights_dir)
diff --git a/hpvm/projects/onnx/frontend/main.py b/hpvm/projects/onnx/frontend/main.py
index b1dba8d21ff4f208bfe033f463c27513eb359bbd..248c52fdd3f93f20a08bec1e0ed8939dc838828a 100644
--- a/hpvm/projects/onnx/frontend/main.py
+++ b/hpvm/projects/onnx/frontend/main.py
@@ -24,21 +24,23 @@ def check_version(model, new_version):
     return model
 
 def compile(model):
-    from config import opset_version_default, src_emit_dir
-    weights_dir = src_emit_dir
-    model = check_version(model, opset_version_default)
+    from config import compile_type, input_size, opset_version_default, src_emit_dir
     from graph_builder import GraphBuilder
     from graph_codegen import GraphCodeGen
     from hpvm_codegen import HpvmCodeGen
-    from config import input_size
+    weights_dir = src_emit_dir
+    model = check_version(model, opset_version_default)
     graphBuilder = GraphBuilder(model, None, "float32", weights_dir)
-    graphCodeGen = GraphCodeGen(graphBuilder.build_graph(), weights_dir, input_size)
-    graphCodeGen.compile()
-    #hpvmCodeGen = HpvmCodeGen(graphBuilder.build_graph(), weights_dir)
-    #hpvmCodeGen.compile()
+    if compile_type == 0:
+        graphCodeGen = GraphCodeGen(graphBuilder.build_graph(), weights_dir, input_size)
+        graphCodeGen.compile()
+    elif compile_type == 1:
+        hpvmCodeGen = HpvmCodeGen(graphBuilder.build_graph(), weights_dir)
+        hpvmCodeGen.compile()
+    else:
+        raise ValueError("Wrong type of Compilation! Abort.")
 
 def main():
-    # TODO: Put it in args
     from config import onnx_file_dir
     model = onnx.load(onnx_file_dir)
     compile(model)