Skip to content
Snippets Groups Projects
Commit ceadfa42 authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Frontend binary should respect -c flag

parent 424a32a2
No related branches found
No related tags found
No related merge requests found
...@@ -58,23 +58,38 @@ typedef struct __attribute__((__packed__)) { ...@@ -58,23 +58,38 @@ typedef struct __attribute__((__packed__)) {
struct ret_t r; struct ret_t r;
} RootIn; } RootIn;
void printUsage(const std::string &bin_name) {
std::cerr << "Usage: " << bin_name << "[-d {test|tune}] [-c CONF_FILE]\n";
}
const int batch_size = {{batch_size}}, input_size = {{input_size}}, batch_count = input_size / batch_size; const int batch_size = {{batch_size}}, input_size = {{input_size}}, batch_count = input_size / batch_size;
int main(int argc, char *argv[]){ int main(int argc, char *argv[]) {
if (argc != 2) { std::string config_path = "", runtype = "test";
std::cout << "Usage: " << argv[0] << " {tune|test}\n"; int flag;
return 1; while ((flag = getopt(argc, argv, "hc:")) != -1) {
} switch (flag) {
std::string arg1 = argv[1]; case 'd':
if (arg1 != "tune" && arg1 != "test") { runtype = std::string(optarg);
std::cout << "Usage: " << argv[0] << " {tune|test}\n"; if (runtype != "test" && runtype != "tune")
return 1; printUsage(argv[0]);
return 1;
break;
case 'c':
config_path = std::string(optarg);
break;
case 'h':
printUsage(argv[0]);
return 0;
default:
printUsage(argv[0]);
return 1;
}
} }
std::string dir_prefix = "{{prefix}}/"; std::string dir_prefix = "{{prefix}}/";
std::string input_path = dir_prefix + arg1 + "_input.bin"; std::string input_path = dir_prefix + "test_input.bin";
std::string labels_path = dir_prefix + arg1 + "_labels.bin"; std::string labels_path = dir_prefix + "test_labels.bin";
{% for w in weights %} {% for w in weights %}
std::string {{w.name}}_path = dir_prefix + "{{w.filename}}"; std::string {{w.name}}_path = dir_prefix + "{{w.filename}}";
void* {{w.name}} = readTrainedWeights({{w.name}}_path.c_str(), 0, {{w.shape|join(', ')}}); void* {{w.name}} = readTrainedWeights({{w.name}}_path.c_str(), 0, {{w.shape|join(', ')}});
...@@ -88,6 +103,10 @@ int main(int argc, char *argv[]){ ...@@ -88,6 +103,10 @@ int main(int argc, char *argv[]){
{% endfor %} {% endfor %}
__hpvm__init(); __hpvm__init();
if (config_path != "") {
llvm_hpvm_initializeRuntimeController(config_path.c_str());
}
startMemTracking(); startMemTracking();
#pragma clang loop unroll(disable) #pragma clang loop unroll(disable)
for (int i = 0; i < batch_count; i++){ for (int i = 0; i < batch_count; i++){
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment