From 8ba057dd2f2b7bd2522dd7cae9d22f7859f9dcd8 Mon Sep 17 00:00:00 2001 From: shingjan <yjshi03@gmail.com> Date: Thu, 4 Jun 2020 23:07:47 -0500 Subject: [PATCH] address comments and FIXME from Tue code walkthrough --- hpvm/projects/onnx/frontend/config.py | 5 +++- hpvm/projects/onnx/frontend/graph_builder.py | 5 ++-- hpvm/projects/onnx/frontend/graph_codegen.py | 26 +++++++++----------- hpvm/projects/onnx/frontend/main.py | 3 ++- 4 files changed, 21 insertions(+), 18 deletions(-) diff --git a/hpvm/projects/onnx/frontend/config.py b/hpvm/projects/onnx/frontend/config.py index b86adfa1ce..c8053b8d98 100644 --- a/hpvm/projects/onnx/frontend/config.py +++ b/hpvm/projects/onnx/frontend/config.py @@ -1,3 +1,6 @@ +model_name = "alexnet" +input_size = [1,2,3,4] onnx_file_dir = "../models/keras/alexnet.onnx" +opset_version_default = 11 src_emit_dir = "./test_src" -opset_version_default = 11 \ No newline at end of file + diff --git a/hpvm/projects/onnx/frontend/graph_builder.py b/hpvm/projects/onnx/frontend/graph_builder.py index 063c8d83e3..737eb1d3e1 100644 --- a/hpvm/projects/onnx/frontend/graph_builder.py +++ b/hpvm/projects/onnx/frontend/graph_builder.py @@ -112,11 +112,12 @@ class GraphBuilder(object): self.tensors[weight_tensor.name].set_mapped_name("weight_" + str(weight_cnt)) weight_cnt += 1 # parse input + input_cnt = 0 for i in self.graph.input: if i.name not in self.tensors: self.tensors[i.name] = InputTensor(i.name) - # FIXME: This input name is hardcoded - self.tensors[i.name].set_mapped_name("input") + self.tensors[i.name].set_mapped_name("input_" + str(input_cnt)) + input_cnt += 1 # parse intermediate tensor for node in self.graph.node: op_name = node.op_type diff --git a/hpvm/projects/onnx/frontend/graph_codegen.py b/hpvm/projects/onnx/frontend/graph_codegen.py index b32132959c..f47abfefab 100644 --- a/hpvm/projects/onnx/frontend/graph_codegen.py +++ b/hpvm/projects/onnx/frontend/graph_codegen.py @@ -6,15 +6,14 @@ from tensor import WeightTensor from utils import skip_layer class GraphCodeGen(object): - def __init__(self, dfg, weights_dir, test_data=None, test_labels=None): + def __init__(self, dfg, weights_dir, test_data_shape=None): self.program_str = "" self.graph = dfg.graph self.tensors = dfg.tensors self.nodes = dfg.nodes self.var_cnt = 0 self.weights_dir = weights_dir - self.test_data = test_data - self.test_labels = test_labels + self.test_data_shape = test_data_shape ################################################ # Aux functions @@ -211,18 +210,18 @@ class GraphCodeGen(object): self.program_str += destructors self.program_str += end_main - def emit_batch_loop(self, x_test=None): - # FIXME: Dimensions from test data not available in ONNX - N = 1#x_test.shape[0] - C = 1#x_test.shape[1] - H = 1#x_test.shape[2] - W = 1#x_test.shape[3] + def emit_batch_loop(self, test_data_shape): + N = test_data_shape[0] + C = test_data_shape[1] + H = test_data_shape[2] + W = test_data_shape[3] loop_str = "" loop_str += "\nstartMemTracking(); \n\n" loop_str += "int test_input_size = " + str(N) + "; \n" loop_str += "int batch_size = " + str(N) + "; \n" + # FIXME: Ceiling for batch_count loop_str += "int batch_count = test_input_size / batch_size; \n" loop_str += "float final_accuracy = 0.0; \n\n" @@ -253,14 +252,14 @@ class GraphCodeGen(object): self.program_str += end_loop_str - def emit_source(self, src_dir): + def emit_source_to_file(self, src_dir): f = open(src_dir + "/src.cc", "w+") f.write(self.program_str) f.close() ################################################ # Compile is a top level function to compile an onnx model into C/C++ - # program with HPVM intrinsics + # program with HPVM Tensor Runtime ################################################ def compile(self): @@ -269,9 +268,8 @@ class GraphCodeGen(object): #os.mkdir(self.weights_dir) self.emit_header() self.emit_weights() - self.emit_batch_loop() + self.emit_batch_loop(self.test_data_shape) self.emit_graph() self.emit_batch_loop_end() self.emit_footer() - # Write the program to source/disk - self.emit_source(self.weights_dir) + self.emit_source_to_file(self.weights_dir) diff --git a/hpvm/projects/onnx/frontend/main.py b/hpvm/projects/onnx/frontend/main.py index 16bf8868b0..6014569acd 100644 --- a/hpvm/projects/onnx/frontend/main.py +++ b/hpvm/projects/onnx/frontend/main.py @@ -30,8 +30,9 @@ def compile(model): from graph_builder import GraphBuilder from graph_codegen import GraphCodeGen from hpvm_codegen import HpvmCodeGen + from config import input_size graphBuilder = GraphBuilder(model, None, "float32", weights_dir) - graphCodeGen = GraphCodeGen(graphBuilder.build_graph(), weights_dir) + graphCodeGen = GraphCodeGen(graphBuilder.build_graph(), weights_dir, input_size) graphCodeGen.compile() #hpvmCodeGen = HpvmCodeGen(graphBuilder.build_graph(), weights_dir) #hpvmCodeGen.compile() -- GitLab