From c8aeb8ef6a26c0134932e75fa368736fb40d6c79 Mon Sep 17 00:00:00 2001
From: Hashim Sharif <hsharif3@miranda.cs.illinois.edu>
Date: Tue, 9 Mar 2021 19:00:13 -0600
Subject: [PATCH] Fixes to Ensure reload_dir can and data_dir take absolute
 paths

---
 .../projects/keras/frontend/approxhpvm_translator.py |  2 +-
 hpvm/projects/keras/src/Benchmark.py                 |  5 ++++-
 hpvm/projects/keras/src/Config.py                    | 12 +++++++++++-
 hpvm/projects/keras/src/alexnet.py                   |  3 ++-
 4 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/hpvm/projects/keras/frontend/approxhpvm_translator.py b/hpvm/projects/keras/frontend/approxhpvm_translator.py
index 59bc0d17c0..1b041a6de4 100644
--- a/hpvm/projects/keras/frontend/approxhpvm_translator.py
+++ b/hpvm/projects/keras/frontend/approxhpvm_translator.py
@@ -978,7 +978,7 @@ class TensorRtTranslator:
 
     self.add_header()
     
-    dir_path = "std::string dir_prefix = std::string(MODEL_PARAMS_DIR) + std::string(\"" + weights_dir +  "\"); \n"
+    dir_path = "std::string dir_prefix = std::string(\"" + weights_dir +  "\"); \n"
     self.weight_str += dir_path
 
     if test_data is not None:
diff --git a/hpvm/projects/keras/src/Benchmark.py b/hpvm/projects/keras/src/Benchmark.py
index 0871b74959..7b0b5447bc 100644
--- a/hpvm/projects/keras/src/Benchmark.py
+++ b/hpvm/projects/keras/src/Benchmark.py
@@ -97,10 +97,13 @@ class Benchmark:
 
       if len(argv) > 2:
         if argv[2] == "frontend":
+
+          if argv[1] == "hpvm_reload": # If reloading HPVM weights use this as directory to load from in HPVM-C generated src
+              self.data_dir = self.reload_dir
           
           # Main call to ApproxHPVM-Keras Frontend
           working_dir = translate_to_approxhpvm(model,
-                                                self.data_dir, self.src_dir,  ##  "data/test_src/", 
+                                                self.data_dir, self.src_dir,   
                                                 X_test, y_test,
                                                 X_tuner, y_tuner,
                                                 self.batch_size, # FIXIT
diff --git a/hpvm/projects/keras/src/Config.py b/hpvm/projects/keras/src/Config.py
index 2edc5c1add..99e696d632 100644
--- a/hpvm/projects/keras/src/Config.py
+++ b/hpvm/projects/keras/src/Config.py
@@ -1,3 +1,13 @@
 
+import pathlib
+
+
 # Path Relative to Model Params Directory
-MODEL_PARAMS_DIR = "../../../hpvm/test/dnn_benchmarks/model_params/"
+abs_path = pathlib.Path(__file__).parent.absolute()
+MODEL_PARAMS_DIR = str(abs_path) + "/../../../../hpvm/test/dnn_benchmarks/model_params/"
+
+
+if __name__ == "__main__":
+
+    abs_path = pathlib.Path(__file__).parent.absolute()
+    print (abs_path)
diff --git a/hpvm/projects/keras/src/alexnet.py b/hpvm/projects/keras/src/alexnet.py
index 9b4d9dfdca..1d8778727b 100644
--- a/hpvm/projects/keras/src/alexnet.py
+++ b/hpvm/projects/keras/src/alexnet.py
@@ -142,7 +142,8 @@ if __name__ == '__main__':
     ### Parameters specific to each benchmark
     reload_dir = MODEL_PARAMS_DIR + '/alexnet_cifar10/'
     keras_model_file = MODEL_PARAMS_DIR + '/alexnet_cifar10/weights.h5'
-    data_dir = '/alexnet_cifar10/' 
+    #data_dir = '/alexnet_cifar10/'
+    data_dir = ''
     src_dir = 'data/alexnet_cifar10_src/'
     num_classes = 10
     batch_size = 500
-- 
GitLab