diff --git a/hpvm/projects/keras/keras_frontend/approxhpvm_translator.py b/hpvm/projects/keras/keras_frontend/approxhpvm_translator.py index 24647107506704a6d049268c7e50ac498281e0e5..87a55913d997c6a893b5075ace6d07f988db0c29 100644 --- a/hpvm/projects/keras/keras_frontend/approxhpvm_translator.py +++ b/hpvm/projects/keras/keras_frontend/approxhpvm_translator.py @@ -1088,10 +1088,7 @@ def createRecursiveDir(target_dir): print ("Delete Directory or Give Different Path. Aborting....") sys.exit(1) - print (target_dir) - toks = target_dir.split("/") - print (toks) for i in range(len(toks)): path_str = "/".join(toks[0:i+1]) if path_str != "": diff --git a/hpvm/projects/keras/keras_frontend/hpvm_dfg_translator.py b/hpvm/projects/keras/keras_frontend/hpvm_dfg_translator.py index 53369478e39e058fd4e4e065d00fb55e0bdc2960..38d82542410d5ec00024db3d0ed21255a76c47dd 100644 --- a/hpvm/projects/keras/keras_frontend/hpvm_dfg_translator.py +++ b/hpvm/projects/keras/keras_frontend/hpvm_dfg_translator.py @@ -695,6 +695,9 @@ class HPVMTranslator: def genMainFunction(self, test_data, batch_size): main_func_str = "int main(int argc, char* argv[]){ \n\n" + + main_func_str += self.GetOptLoop() + main_func_str += self.weight_str main_func_str += self.input_str main_func_str += "\n" + HPVM_init + "(); \n" @@ -733,6 +736,9 @@ class HPVMTranslator: def genTunerMainFunction(self, src_dir, test_data, batch_size): tuner_main_func_str = "int main(int argc, char* argv[]){ \n\n" + + tuner_main_func_str += self.GetOptLoop() + tuner_main_func_str += self.weight_str tuner_main_func_str += self.input_str tuner_main_func_str += "RootIn* args = static_cast<RootIn*>(malloc(sizeof(RootIn))); \n\n" @@ -833,18 +839,58 @@ void write_accuracy(float accuracy) { fout << std::fixed << accuracy; } +""" + + return FIFO_str + + + def getUsageStr(self): + usage_str = """ -""" +void printUsage(){ + std::cerr << \"Usage: -d {test|tune} -c {config_file_path} \"; + abort(); +} - return FIFO_str +""" + return usage_str + def GetOptLoop(self): + + getopt_str = """ + + std::string runtype; + std::string config_path; + int flag; + while ( (flag = getopt (argc, argv, "d:c:")) != -1){ + switch (flag) + { + case 'd': + runtype = std::string(optarg); + if (runtype != "test" && runtype != "tune") + printUsage(); + break; + case 'c': + config_path = std::string(optarg); + break; + default: + printUsage(); + } + } + +""" + + return getopt_str + + + def generateTestProgram(self, dir_prefix): program_str = self.file_header_str + self.node_str + self.root_str - program_str += self.root_struct_str + self.main_func_str + program_str += self.root_struct_str + self.getUsageStr() + self.main_func_str DEBUG (program_str) @@ -856,8 +902,8 @@ void write_accuracy(float accuracy) { def generateTunerProgram(self, dir_prefix, FIFO_str): - program_str = self.file_header_str + FIFO_str + self.node_str + self.root_str - program_str += self.root_struct_str + self.tuner_main_func_str + program_str = self.file_header_str + FIFO_str + self.node_str + self.root_str + program_str += self.root_struct_str + self.getUsageStr() + self.tuner_main_func_str DEBUG (program_str)