Skip to content
Snippets Groups Projects
Commit c8aeb8ef authored by Hashim Sharif's avatar Hashim Sharif
Browse files

Fixes to Ensure reload_dir can and data_dir take absolute paths

parent 10db78d0
No related branches found
No related tags found
No related merge requests found
......@@ -978,7 +978,7 @@ class TensorRtTranslator:
self.add_header()
dir_path = "std::string dir_prefix = std::string(MODEL_PARAMS_DIR) + std::string(\"" + weights_dir + "\"); \n"
dir_path = "std::string dir_prefix = std::string(\"" + weights_dir + "\"); \n"
self.weight_str += dir_path
if test_data is not None:
......
......@@ -97,10 +97,13 @@ class Benchmark:
if len(argv) > 2:
if argv[2] == "frontend":
if argv[1] == "hpvm_reload": # If reloading HPVM weights use this as directory to load from in HPVM-C generated src
self.data_dir = self.reload_dir
# Main call to ApproxHPVM-Keras Frontend
working_dir = translate_to_approxhpvm(model,
self.data_dir, self.src_dir, ## "data/test_src/",
self.data_dir, self.src_dir,
X_test, y_test,
X_tuner, y_tuner,
self.batch_size, # FIXIT
......
import pathlib
# Path Relative to Model Params Directory
MODEL_PARAMS_DIR = "../../../hpvm/test/dnn_benchmarks/model_params/"
abs_path = pathlib.Path(__file__).parent.absolute()
MODEL_PARAMS_DIR = str(abs_path) + "/../../../../hpvm/test/dnn_benchmarks/model_params/"
if __name__ == "__main__":
abs_path = pathlib.Path(__file__).parent.absolute()
print (abs_path)
......@@ -142,7 +142,8 @@ if __name__ == '__main__':
### Parameters specific to each benchmark
reload_dir = MODEL_PARAMS_DIR + '/alexnet_cifar10/'
keras_model_file = MODEL_PARAMS_DIR + '/alexnet_cifar10/weights.h5'
data_dir = '/alexnet_cifar10/'
#data_dir = '/alexnet_cifar10/'
data_dir = ''
src_dir = 'data/alexnet_cifar10_src/'
num_classes = 10
batch_size = 500
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment