diff --git a/hpvm/projects/keras/keras_frontend/hpvm_dfg_translator.py b/hpvm/projects/keras/keras_frontend/hpvm_dfg_translator.py index 38d82542410d5ec00024db3d0ed21255a76c47dd..08dd896478684ab0a1955c16463d6b75bb3ae78a 100644 --- a/hpvm/projects/keras/keras_frontend/hpvm_dfg_translator.py +++ b/hpvm/projects/keras/keras_frontend/hpvm_dfg_translator.py @@ -682,7 +682,7 @@ class HPVMTranslator: input_str += "std::string input_path = test_input_path; \n" input_str += "std::string labels_path = test_labels_path; \n\n" - input_str += "if (argc >= 2 && std::string(argv[1]) == \"tune\"){ \n" + input_str += "if (runtype == \"tune\"){ \n" input_str += " input = tune_input; \n" input_str += " input_path = tune_input_path; \n" input_str += " labels_path = tune_labels_path; \n\n" @@ -701,6 +701,15 @@ class HPVMTranslator: main_func_str += self.weight_str main_func_str += self.input_str main_func_str += "\n" + HPVM_init + "(); \n" + + main_func_str += """ + +if(config_path != ""){ + llvm_hpvm_initializeRuntimeController(config_path.c_str()); +} + + """ + main_func_str += "RootIn* args = static_cast<RootIn*>(malloc(sizeof(RootIn))); \n\n" main_func_str += self.handleTuneTestData() @@ -751,7 +760,16 @@ class HPVMTranslator: tuner_main_func_str += "\nint ret = 0; \n" tuner_main_func_str += "while ((ret = fifo_wait())) { \n" - tuner_main_func_str += "\n" + HPVM_init + "(); \n\n" + tuner_main_func_str += "\n" + HPVM_init + "(); \n" + + tuner_main_func_str += """ + +if(config_path != ""){ + llvm_hpvm_initializeRuntimeController(config_path.c_str()); +} + + """ + tuner_main_func_str += "std::string input_pth = (ret == 1 ? test_input_path : tune_input_path); \n" tuner_main_func_str += "std::string labels_pth = (ret == 1 ? test_labels_path : tune_labels_path); \n" @@ -863,7 +881,7 @@ void printUsage(){ getopt_str = """ std::string runtype; - std::string config_path; + std::string config_path = ""; int flag; while ( (flag = getopt (argc, argv, "d:c:")) != -1){ switch (flag)