Skip to content
Snippets Groups Projects
Commit 12e2d791 authored by Hashim Sharif's avatar Hashim Sharif
Browse files

Merging

parents 85ef7038 3f4299bb
No related branches found
No related tags found
No related merge requests found
......@@ -443,17 +443,32 @@ class ResNet18_CIFAR10(Benchmark):
mean = np.mean(X_train)
std = np.std(X_train)
# X_train = (X_train - mean) / (std + 1e-7)
# X_val = (X_val - mean) / (std + 1e-7)
# X_val = (X_val - mean) / (std + 1e-7)
X_train = (X_train - mean)
X_val = (X_val - mean)
X_val = (X_val - mean)
X_test_val = np.fromfile(MODEL_PARAMS_DIR + '/resnet18_cifar10/test_input.bin', dtype=np.float32)
Y_test_val = np.fromfile(MODEL_PARAMS_DIR + '/resnet18_cifar10/test_labels.bin', dtype=np.uint32)
X_test_val = X_test_val.reshape((-1,3,32,32))
X_tune_val = np.fromfile(MODEL_PARAMS_DIR + '/resnet18_cifar10/tune_input.bin', dtype=np.float32)
Y_tune_val = np.fromfile(MODEL_PARAMS_DIR + '/resnet18_cifar10/tune_labels.bin', dtype=np.uint32)
X_tune_val = X_tune_val.reshape((-1,3,32,32))
X_test = X_test_val[:5000]
y_test= Y_test_val[:5000]
X_tuner = X_tune_val[:5000]
y_tuner = Y_tune_val[:5000]
X_test = X_val[0:5000]
y_test = y_val[0:5000]
X_tuner = X_val[5000:]
y_tuner = y_val[5000:]
return X_train, y_train, X_test, y_test, X_tuner, y_tuner
def trainModel(self, model, X_train, y_train, X_test, y_test):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment