Skip to content
Snippets Groups Projects
Commit fc2c9afb authored by shingjan's avatar shingjan
Browse files

template ready, working on hpvm codegen

parent 5744a820
No related branches found
No related tags found
No related merge requests found
class GraphCodeGen:
def __init__(self, graph):
self._headers = ""
self._nodes = ""
self._root = ""
self._root_struct = ""
self._main_func = ""
class HpvmCodeGen:
def __init__(self, DFG, weights_dir, test_data=None, test_labels=None):
self.program_str = ""
self.graph = DFG.graph
self.tensors = DFG.tensors
self.nodes = DFG.nodes
self.var_cnt = 0
self.weights_dir = weights_dir
self.test_data = test_data
self.test_labels = test_labels
self.filter_names = {} # essentially tensors
def emitHeaders(self):
def emit_header(self):
headers = "\n#include <stdio.h> \n"
headers += "#include <stdlib.h> \n"
headers += "#include <unistd.h> \n"
......@@ -16,9 +20,9 @@ class GraphCodeGen:
headers += "#include <visc.h> \n"
headers += "#include <tensorTypes.h> \n"
headers += "#include <tensorUtils.h> \n\n"
self._headers = headers
self.program_str += headers
def emitRoot(self):
def emit_root(self):
def emitRootNodeHeader():
root_signature = "void root("
index = 0
......@@ -39,8 +43,8 @@ class GraphCodeGen:
for f_name in self.filter_names:
root_signature += f_name
if index < len(self.filter_names) - 1:
root_signature += ", "
index += 1
root_signature += ", "
index += 1
root_signature += ", 0); \n\n"
return root_signature
......@@ -69,15 +73,15 @@ class GraphCodeGen:
root_struct += "}\nRootIn;\n\n"
return root_struct
self._root += emitRootNodeHeader()
self._root_struct += emitRootStructure()
self.codegen(self.dfg)
self._root += emitRootNodeFooter()
self.program_str += emitRootNodeHeader()
self.program_str += emitRootStructure()
# self.codegen(self.dfg)
self.program_str += emitRootNodeFooter()
def emitMainFunc(self, test_data):
def emit_main(self, test_data):
main_func_str = "int main(){ \n\n"
main_func_str += self.weight_str
main_func_str += self.input_str
#main_func_str += self.weight_str
#main_func_str += self.input_str
main_func_str += "\n__visc__init(); \n"
main_func_str += "RootIn* args = static_cast<RootIn*>(malloc(sizeof(RootIn))); \n\n"
for f_name in self.filter_names:
......@@ -91,19 +95,17 @@ class GraphCodeGen:
main_func_str += "computeAccuracy3(labels, result); \n"
main_func_str += "return 0; \n\n"
main_func_str += "} \n"
self._main_func += main_func_str
self.program_str += main_func_str
def emitSource(self, dir_prefix):
source = self._headers + self._nodes + self._root
source += self._root_struct + self._main_func
print(source)
def emit_source(self, dir_prefix):
print(self.program_str)
f = open(dir_prefix + "/approxhpvm_src.cc", "w+")
f.write(source)
f.write(self.program_str)
f.close()
def compile(self, model, weights_dir, test_data):
self.emitHeaders()
self.emitRoot()
self.emitMainFunc(test_data)
def compile(self):
self.emit_header()
# self.emitRoot()
self.emit_main(self.test_data)
# dump generated program string to source file
self.emitSource(weights_dir)
self.emit_source(self.weights_dir)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment