Skip to content
Snippets Groups Projects
Commit 8ba057dd authored by shingjan's avatar shingjan
Browse files

address comments and FIXME from Tue code walkthrough

parent 29c1b0ff
No related branches found
No related tags found
No related merge requests found
model_name = "alexnet"
input_size = [1,2,3,4]
onnx_file_dir = "../models/keras/alexnet.onnx" onnx_file_dir = "../models/keras/alexnet.onnx"
opset_version_default = 11
src_emit_dir = "./test_src" src_emit_dir = "./test_src"
opset_version_default = 11
\ No newline at end of file
...@@ -112,11 +112,12 @@ class GraphBuilder(object): ...@@ -112,11 +112,12 @@ class GraphBuilder(object):
self.tensors[weight_tensor.name].set_mapped_name("weight_" + str(weight_cnt)) self.tensors[weight_tensor.name].set_mapped_name("weight_" + str(weight_cnt))
weight_cnt += 1 weight_cnt += 1
# parse input # parse input
input_cnt = 0
for i in self.graph.input: for i in self.graph.input:
if i.name not in self.tensors: if i.name not in self.tensors:
self.tensors[i.name] = InputTensor(i.name) self.tensors[i.name] = InputTensor(i.name)
# FIXME: This input name is hardcoded self.tensors[i.name].set_mapped_name("input_" + str(input_cnt))
self.tensors[i.name].set_mapped_name("input") input_cnt += 1
# parse intermediate tensor # parse intermediate tensor
for node in self.graph.node: for node in self.graph.node:
op_name = node.op_type op_name = node.op_type
......
...@@ -6,15 +6,14 @@ from tensor import WeightTensor ...@@ -6,15 +6,14 @@ from tensor import WeightTensor
from utils import skip_layer from utils import skip_layer
class GraphCodeGen(object): class GraphCodeGen(object):
def __init__(self, dfg, weights_dir, test_data=None, test_labels=None): def __init__(self, dfg, weights_dir, test_data_shape=None):
self.program_str = "" self.program_str = ""
self.graph = dfg.graph self.graph = dfg.graph
self.tensors = dfg.tensors self.tensors = dfg.tensors
self.nodes = dfg.nodes self.nodes = dfg.nodes
self.var_cnt = 0 self.var_cnt = 0
self.weights_dir = weights_dir self.weights_dir = weights_dir
self.test_data = test_data self.test_data_shape = test_data_shape
self.test_labels = test_labels
################################################ ################################################
# Aux functions # Aux functions
...@@ -211,18 +210,18 @@ class GraphCodeGen(object): ...@@ -211,18 +210,18 @@ class GraphCodeGen(object):
self.program_str += destructors self.program_str += destructors
self.program_str += end_main self.program_str += end_main
def emit_batch_loop(self, x_test=None): def emit_batch_loop(self, test_data_shape):
# FIXME: Dimensions from test data not available in ONNX N = test_data_shape[0]
N = 1#x_test.shape[0] C = test_data_shape[1]
C = 1#x_test.shape[1] H = test_data_shape[2]
H = 1#x_test.shape[2] W = test_data_shape[3]
W = 1#x_test.shape[3]
loop_str = "" loop_str = ""
loop_str += "\nstartMemTracking(); \n\n" loop_str += "\nstartMemTracking(); \n\n"
loop_str += "int test_input_size = " + str(N) + "; \n" loop_str += "int test_input_size = " + str(N) + "; \n"
loop_str += "int batch_size = " + str(N) + "; \n" loop_str += "int batch_size = " + str(N) + "; \n"
# FIXME: Ceiling for batch_count
loop_str += "int batch_count = test_input_size / batch_size; \n" loop_str += "int batch_count = test_input_size / batch_size; \n"
loop_str += "float final_accuracy = 0.0; \n\n" loop_str += "float final_accuracy = 0.0; \n\n"
...@@ -253,14 +252,14 @@ class GraphCodeGen(object): ...@@ -253,14 +252,14 @@ class GraphCodeGen(object):
self.program_str += end_loop_str self.program_str += end_loop_str
def emit_source(self, src_dir): def emit_source_to_file(self, src_dir):
f = open(src_dir + "/src.cc", "w+") f = open(src_dir + "/src.cc", "w+")
f.write(self.program_str) f.write(self.program_str)
f.close() f.close()
################################################ ################################################
# Compile is a top level function to compile an onnx model into C/C++ # Compile is a top level function to compile an onnx model into C/C++
# program with HPVM intrinsics # program with HPVM Tensor Runtime
################################################ ################################################
def compile(self): def compile(self):
...@@ -269,9 +268,8 @@ class GraphCodeGen(object): ...@@ -269,9 +268,8 @@ class GraphCodeGen(object):
#os.mkdir(self.weights_dir) #os.mkdir(self.weights_dir)
self.emit_header() self.emit_header()
self.emit_weights() self.emit_weights()
self.emit_batch_loop() self.emit_batch_loop(self.test_data_shape)
self.emit_graph() self.emit_graph()
self.emit_batch_loop_end() self.emit_batch_loop_end()
self.emit_footer() self.emit_footer()
# Write the program to source/disk self.emit_source_to_file(self.weights_dir)
self.emit_source(self.weights_dir)
...@@ -30,8 +30,9 @@ def compile(model): ...@@ -30,8 +30,9 @@ def compile(model):
from graph_builder import GraphBuilder from graph_builder import GraphBuilder
from graph_codegen import GraphCodeGen from graph_codegen import GraphCodeGen
from hpvm_codegen import HpvmCodeGen from hpvm_codegen import HpvmCodeGen
from config import input_size
graphBuilder = GraphBuilder(model, None, "float32", weights_dir) graphBuilder = GraphBuilder(model, None, "float32", weights_dir)
graphCodeGen = GraphCodeGen(graphBuilder.build_graph(), weights_dir) graphCodeGen = GraphCodeGen(graphBuilder.build_graph(), weights_dir, input_size)
graphCodeGen.compile() graphCodeGen.compile()
#hpvmCodeGen = HpvmCodeGen(graphBuilder.build_graph(), weights_dir) #hpvmCodeGen = HpvmCodeGen(graphBuilder.build_graph(), weights_dir)
#hpvmCodeGen.compile() #hpvmCodeGen.compile()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment