diff --git a/hpvm/projects/torch2hpvm/torch2hpvm/template_hpvm.cpp.in b/hpvm/projects/torch2hpvm/torch2hpvm/template_hpvm.cpp.in index d74893984561297c9dc60d43a83d2c677885f87e..208cdfe6169f6baae95522720b8c850aff12a3a0 100644 --- a/hpvm/projects/torch2hpvm/torch2hpvm/template_hpvm.cpp.in +++ b/hpvm/projects/torch2hpvm/torch2hpvm/template_hpvm.cpp.in @@ -58,23 +58,38 @@ typedef struct __attribute__((__packed__)) { struct ret_t r; } RootIn; +void printUsage(const std::string &bin_name) { + std::cerr << "Usage: " << bin_name << "[-d {test|tune}] [-c CONF_FILE]\n"; +} const int batch_size = {{batch_size}}, input_size = {{input_size}}, batch_count = input_size / batch_size; -int main(int argc, char *argv[]){ - if (argc != 2) { - std::cout << "Usage: " << argv[0] << " {tune|test}\n"; - return 1; - } - std::string arg1 = argv[1]; - if (arg1 != "tune" && arg1 != "test") { - std::cout << "Usage: " << argv[0] << " {tune|test}\n"; - return 1; +int main(int argc, char *argv[]) { + std::string config_path = "", runtype = "test"; + int flag; + while ((flag = getopt(argc, argv, "hc:")) != -1) { + switch (flag) { + case 'd': + runtype = std::string(optarg); + if (runtype != "test" && runtype != "tune") + printUsage(argv[0]); + return 1; + break; + case 'c': + config_path = std::string(optarg); + break; + case 'h': + printUsage(argv[0]); + return 0; + default: + printUsage(argv[0]); + return 1; + } } std::string dir_prefix = "{{prefix}}/"; - std::string input_path = dir_prefix + arg1 + "_input.bin"; - std::string labels_path = dir_prefix + arg1 + "_labels.bin"; + std::string input_path = dir_prefix + "test_input.bin"; + std::string labels_path = dir_prefix + "test_labels.bin"; {% for w in weights %} std::string {{w.name}}_path = dir_prefix + "{{w.filename}}"; void* {{w.name}} = readTrainedWeights({{w.name}}_path.c_str(), 0, {{w.shape|join(', ')}}); @@ -88,6 +103,10 @@ int main(int argc, char *argv[]){ {% endfor %} __hpvm__init(); + if (config_path != "") { + llvm_hpvm_initializeRuntimeController(config_path.c_str()); + } + startMemTracking(); #pragma clang loop unroll(disable) for (int i = 0; i < batch_count; i++){