Skip to content
Snippets Groups Projects
Commit d58f0fd2 authored by nz11's avatar nz11
Browse files

Update alexnet_imagenet.py

parent 0466d6f8
No related branches found
No related tags found
No related merge requests found
...@@ -30,10 +30,10 @@ data_format = 'channels_first' ...@@ -30,10 +30,10 @@ data_format = 'channels_first'
IMAGENET_DIR = '/home/nz11/ILSVRC2012/' IMAGENET_DIR = '/home/nz11/ILSVRC2012/'
OUTPUT_DIR = 'data/alexnet_imagenet_tune/' OUTPUT_DIR = 'data/alexnet_imagenet_tune/'
WEIGHTS_PATH = 'data/weights.h5' WEIGHTS_PATH = 'data/alexnet_imagenet_tune/weights.h5'
NUM_CLASSES = 200 NUM_CLASSES = 200
IMAGES_PER_CLASS = 40 IMAGES_PER_CLASS = 50
# VAL_SIZE = 100 # VAL_SIZE = 100
...@@ -183,6 +183,9 @@ y_true = np.array(y_true) ...@@ -183,6 +183,9 @@ y_true = np.array(y_true)
X_tune = np.array(X_tune) X_tune = np.array(X_tune)
y_tune = np.array(y_tune) y_tune = np.array(y_tune)
print ('tune size', len(X_tune))
print ('test size', len(X_test))
...@@ -233,16 +236,16 @@ model.compile(optimizer=keras.optimizers.Adam(lr=0.00001), loss='categorical_cro ...@@ -233,16 +236,16 @@ model.compile(optimizer=keras.optimizers.Adam(lr=0.00001), loss='categorical_cro
if os.path.exists(WEIGHTS_PATH): if os.path.exists(WEIGHTS_PATH):
model.load_weights(WEIGHTS_PATH) model.load_weights(WEIGHTS_PATH)
else: else:
model.fit_generator(generate(), steps_per_epoch=1000, validation_data=(X_test, to_categorical(y_true, num_classes=1000)), epochs=2) pass
K.set_value(model.optimizer.lr, 0.000001) # model.fit_generator(generate(), steps_per_epoch=1000, validation_data=(X_test, to_categorical(y_true, num_classes=1000)), epochs=3)
model.fit_generator(generate(), steps_per_epoch=1000, validation_data=(X_test, to_categorical(y_true, num_classes=1000)), epochs=6) # K.set_value(model.optimizer.lr, 0.000001)
model.save_weights('data/weights.h5') # model.fit_generator(generate(), steps_per_epoch=1000, validation_data=(X_test, to_categorical(y_true, num_classes=1000)), epochs=3)
translate_to_approxhpvm(model, OUTPUT_DIR, X_tune, y_tune, 1000) translate_to_approxhpvm(model, OUTPUT_DIR, X_tune, y_tune, 1000, dump_weights=False)
# dumpCalibrationData2(OUTPUT_DIR + 'test_input_10K.bin', X_test, OUTPUT_DIR + 'test_labels_10K.bin', y_true) # # dumpCalibrationData2(OUTPUT_DIR + 'test_input_10K.bin', X_test, OUTPUT_DIR + 'test_labels_10K.bin', y_true)
dumpCalibrationData2(OUTPUT_DIR + 'tune_input.bin', X_tune, OUTPUT_DIR + 'tune_labels.bin', y_tune) # dumpCalibrationData2(OUTPUT_DIR + 'tune_input.bin', X_tune, OUTPUT_DIR + 'tune_labels.bin', y_tune)
dumpCalibrationData2(OUTPUT_DIR + 'test_input.bin', X_test, OUTPUT_DIR + 'test_labels.bin', y_true) # dumpCalibrationData2(OUTPUT_DIR + 'test_input.bin', X_test, OUTPUT_DIR + 'test_labels.bin', y_true)
pred = np.argmax(model.predict(X_test), axis=1) pred = np.argmax(model.predict(X_test), axis=1)
...@@ -250,4 +253,7 @@ print ('val accuracy', np.sum(pred == y_true.ravel()) / len(X_test)) ...@@ -250,4 +253,7 @@ print ('val accuracy', np.sum(pred == y_true.ravel()) / len(X_test))
pred = np.argmax(model.predict(X_tune), axis=1) pred = np.argmax(model.predict(X_tune), axis=1)
print ('val accuracy', np.sum(pred == y_tune.ravel()) / len(X_tune)) print ('val accuracy', np.sum(pred == y_tune.ravel()) / len(X_tune))
model.save_weights(OUTPUT_DIR + '/weights.h5')
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment