Skip to content
Snippets Groups Projects
Commit 317b22f6 authored by Hashim Sharif's avatar Hashim Sharif
Browse files

Adding model reloading and training options to AlexNet Keras script

parent 49cff2c8
No related branches found
No related tags found
No related merge requests found
......@@ -38,7 +38,7 @@ def lr_schedule(epoch):
def buildModel2():
def buildModel():
activation_type = "tanh"
weight_decay = 1e-4
......@@ -73,7 +73,7 @@ def buildModel2():
def buildModel():
def buildModel_old():
model = Sequential()
model.add(Conv2D(128, kernel_size=(3, 3), activation='tanh', input_shape=(3, 32, 32), padding = 'same'))
......@@ -157,8 +157,6 @@ def trainModel(model):
#dumpCalibrationData("calibration_data/alexnet_calib.bin", X_train,
# "calibration_data/alexnet_train_labels.bin", train_labels)
translate_to_approxhpvm(model, "data/alexnet_cifar10/", X_test, test_labels, 10)
......@@ -194,26 +192,39 @@ def data_preprocess():
if __name__ == "__main__":
if len(sys.argv) < 2:
sys.exit(0)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Changing to NCHW format
K.set_image_data_format('channels_first')
model = buildModel2()
model = buildModel()
X_train, Y_train, X_test, Y_test = data_preprocess()
###reloadFP32HPVMModel(model, "/home/hsharif3/Gitlab/hpvm/llvm/projects/hpvm-tensor-rt/model_params/alexnet_cifar10/")
reload_dir = "/home/hsharif3/Gitlab/hpvm/llvm/projects/hpvm-tensor-rt/model_params/alexnet_cifar10/"
keras_model_file = "alexnet.h5"
model = dumpHPVMToKerasModel(model, reload_dir, keras_model_file, X_test, Y_test)
if sys.argv[1] == "hpvm_reload":
model = dumpHPVMToKerasModel(model, reload_dir, keras_model_file, X_test, Y_test)
if sys.argv[1] == "keras_reload":
model = load_model(keras_model_file)
if sys.argv[1] == "train":
model = trainModel(model)
num_classes = 10
score = model.evaluate(X_test, to_categorical(Y_test, num_classes), verbose=0)
print('Test accuracy2:', score[1])
reloadKerasModel(keras_model_file)
if len(sys.argv) > 2 and sys.argv[2] == "frontend":
if sys.argv[1] != "hpvm_reload":
print("ERROR: Must load HPVM model to invoke frontend")
sys.exit(1)
hpvm_dir = "data/alexnet_cifar10/"
translate_to_approxhpvm(model, hpvm_dir, X_test, Y_test, num_classes)
### trainModel(model)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment