From 3f4299bbcba90987948280a1800ac26c3727f9d0 Mon Sep 17 00:00:00 2001
From: Abdul Rafae Noor <arnoor2@tyler.cs.illinois.edu>
Date: Tue, 2 Feb 2021 12:12:55 -0600
Subject: [PATCH] Test changing model to use test.bin data

---
 hpvm/projects/keras/src/resnet18_cifar10.py | 29 ++++++++++++++++-----
 1 file changed, 22 insertions(+), 7 deletions(-)

diff --git a/hpvm/projects/keras/src/resnet18_cifar10.py b/hpvm/projects/keras/src/resnet18_cifar10.py
index 1367c0830b..74abc7ad9f 100644
--- a/hpvm/projects/keras/src/resnet18_cifar10.py
+++ b/hpvm/projects/keras/src/resnet18_cifar10.py
@@ -443,17 +443,32 @@ class ResNet18_CIFAR10(Benchmark):
         mean = np.mean(X_train)
         std = np.std(X_train)
 #         X_train = (X_train - mean) / (std + 1e-7)
-#         X_val = (X_val - mean) / (std + 1e-7)  
+#         X_val = (X_val - mean) / (std + 1e-7)
         X_train = (X_train - mean)
-        X_val = (X_val - mean) 
+        X_val = (X_val - mean)
+
+
+        X_test_val = np.fromfile(MODEL_PARAMS_DIR + '/resnet18_cifar10/test_input.bin', dtype=np.float32)
+        Y_test_val = np.fromfile(MODEL_PARAMS_DIR + '/resnet18_cifar10/test_labels.bin', dtype=np.uint32)
+
+        X_test_val = X_test_val.reshape((-1,3,32,32))
+
+
+        X_tune_val = np.fromfile(MODEL_PARAMS_DIR + '/resnet18_cifar10/tune_input.bin', dtype=np.float32)
+        Y_tune_val = np.fromfile(MODEL_PARAMS_DIR + '/resnet18_cifar10/tune_labels.bin', dtype=np.uint32)
+
+        X_tune_val = X_tune_val.reshape((-1,3,32,32))
+
+
+        X_test = X_test_val[:5000]
+        y_test= Y_test_val[:5000]
+
+        X_tuner = X_tune_val[:5000]
+        y_tuner = Y_tune_val[:5000]
 
-        X_test = X_val[0:5000]
-        y_test = y_val[0:5000]
-        X_tuner = X_val[5000:]
-        y_tuner = y_val[5000:]
 
         return X_train, y_train, X_test, y_test, X_tuner, y_tuner
-    
+
 
     def trainModel(self, model, X_train, y_train, X_test, y_test):
 
-- 
GitLab