From ceadfa42bb9ed591071012ac4a1650a73b14af74 Mon Sep 17 00:00:00 2001
From: Yifan Zhao <yifanz16@illinois.edu>
Date: Sun, 4 Apr 2021 03:43:21 -0500
Subject: [PATCH] Frontend binary should respect -c flag

---
 .../torch2hpvm/template_hpvm.cpp.in           | 41 ++++++++++++++-----
 1 file changed, 30 insertions(+), 11 deletions(-)

diff --git a/hpvm/projects/torch2hpvm/torch2hpvm/template_hpvm.cpp.in b/hpvm/projects/torch2hpvm/torch2hpvm/template_hpvm.cpp.in
index d748939845..208cdfe616 100644
--- a/hpvm/projects/torch2hpvm/torch2hpvm/template_hpvm.cpp.in
+++ b/hpvm/projects/torch2hpvm/torch2hpvm/template_hpvm.cpp.in
@@ -58,23 +58,38 @@ typedef struct __attribute__((__packed__)) {
   struct ret_t r;
 } RootIn;
 
+void printUsage(const std::string &bin_name) {
+  std::cerr << "Usage: " << bin_name << "[-d {test|tune}] [-c CONF_FILE]\n";
+}
 
 const int batch_size = {{batch_size}}, input_size = {{input_size}}, batch_count = input_size / batch_size;
 
-int main(int argc, char *argv[]){
-  if (argc != 2) {
-    std::cout << "Usage: " << argv[0] << " {tune|test}\n";
-    return 1;
-  }
-  std::string arg1 = argv[1];
-  if (arg1 != "tune" && arg1 != "test") {
-    std::cout << "Usage: " << argv[0] << " {tune|test}\n";
-    return 1;
+int main(int argc, char *argv[]) {
+  std::string config_path = "", runtype = "test";
+  int flag;
+  while ((flag = getopt(argc, argv, "hc:")) != -1) {
+    switch (flag) {
+    case 'd':
+      runtype = std::string(optarg);
+      if (runtype != "test" && runtype != "tune")
+        printUsage(argv[0]);
+        return 1;
+      break;
+    case 'c':
+      config_path = std::string(optarg);
+      break;
+    case 'h':
+      printUsage(argv[0]);
+      return 0;
+    default:
+      printUsage(argv[0]);
+      return 1;
+    }
   }
 
   std::string dir_prefix = "{{prefix}}/";
-  std::string input_path = dir_prefix + arg1 + "_input.bin";
-  std::string labels_path = dir_prefix + arg1 + "_labels.bin";
+  std::string input_path = dir_prefix + "test_input.bin";
+  std::string labels_path = dir_prefix + "test_labels.bin";
 {% for w in weights %}
   std::string {{w.name}}_path = dir_prefix + "{{w.filename}}";
   void* {{w.name}} = readTrainedWeights({{w.name}}_path.c_str(), 0, {{w.shape|join(', ')}});
@@ -88,6 +103,10 @@ int main(int argc, char *argv[]){
 {% endfor %}
 
   __hpvm__init();
+  if (config_path != "") {
+    llvm_hpvm_initializeRuntimeController(config_path.c_str());
+  }
+
   startMemTracking();
 #pragma clang loop unroll(disable)
   for (int i = 0; i < batch_count; i++){
-- 
GitLab