Skip to content
Snippets Groups Projects
Commit 504ad4e9 authored by Nathan Zhao's avatar Nathan Zhao
Browse files

fix imagenet benchmark bugs

parent 8fbacdab
No related branches found
No related tags found
No related merge requests found
......@@ -21,28 +21,19 @@ from Config import MODEL_PARAMS_DIR
IMAGENET_DIR = '/home/nz11/ILSVRC2012/'
NUM_TUNE_CLASSES = 200
IMAGES_PER_CLASS = 50
class AlexNet(Benchmark):
def data_preprocess(self):
X_val = np.fromfile(MODEL_PARAMS_DIR + '/alexnet_imagenet/test_input.bin', dtype=np.float32)
y_val = np.fromfile(MODEL_PARAMS_DIR + '/alexnet_imagenet/test_labels.bin', dtype=np.uint32)
X_val = X_val.reshape((-1, 3, 224, 224))
X_train, y_train = None, None
X_test = X_val[0:5000]
y_test = y_val[0:5000]
X_tuner = X_val[5000:]
y_tuner = y_val[5000:]
X_test = np.fromfile(MODEL_PARAMS_DIR + '/alexnet_imagenet/test_input.bin', dtype=np.float32)
X_test = X_test.reshape((-1, 3, 224, 224))
y_test = np.fromfile(MODEL_PARAMS_DIR + '/alexnet_imagenet/test_labels.bin', dtype=np.uint32)
X_tuner = np.fromfile(MODEL_PARAMS_DIR + '/alexnet_imagenet/tune_input.bin', dtype=np.float32)
X_tuner = X_tuner.reshape((-1, 3, 224, 224))
y_tuner = np.fromfile(MODEL_PARAMS_DIR + '/alexnet_imagenet/tune_labels.bin', dtype=np.uint32)
return X_train, y_train, X_test, y_test, X_tuner, y_tuner
......
......@@ -51,7 +51,6 @@ class LeNet_MNIST(Benchmark):
def data_preprocess(self):
(X_train, y_train), (X_val, y_val) = mnist.load_data()
test_labels = y_val
......
......@@ -16,18 +16,11 @@ from keras.utils import to_categorical
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import LearningRateScheduler
from keras.applications.resnet50 import preprocess_input
from Benchmark import Benchmark
from Config import MODEL_PARAMS_DIR
IMAGENET_DIR = '/home/nz11/ILSVRC2012/'
NUM_TUNE_CLASSES = 200
IMAGES_PER_CLASS = 50
class ResNet50(Benchmark):
def buildModel(self):
......@@ -120,19 +113,16 @@ class ResNet50(Benchmark):
def data_preprocess(self):
X_val = np.fromfile(MODEL_PARAMS_DIR + '/resnet50_imagenet/test_input.bin', dtype=np.float32)
y_val = np.fromfile(MODEL_PARAMS_DIR + '/resnet50_imagenet/test_labels.bin', dtype=np.uint32)
X_val = X_val.reshape((-1, 3, 224, 224))
X_train, y_train = None, None
X_test = X_val[0:5000]
y_test = y_val[0:5000]
X_tuner = X_val[5000:]
y_tuner = y_val[5000:]
X_test = np.fromfile(MODEL_PARAMS_DIR + '/resnet50_imagenet/test_input.bin', dtype=np.float32)
X_test = X_test.reshape((-1, 3, 224, 224))
y_test = np.fromfile(MODEL_PARAMS_DIR + '/resnet50_imagenet/test_labels.bin', dtype=np.uint32)
X_tuner = np.fromfile(MODEL_PARAMS_DIR + '/resnet50_imagenet/tune_input.bin', dtype=np.float32)
X_tuner = X_tuner.reshape((-1, 3, 224, 224))
y_tuner = np.fromfile(MODEL_PARAMS_DIR + '/resnet50_imagenet/tune_labels.bin', dtype=np.uint32)
return X_train, y_train, X_test, y_test, X_tuner, y_tuner
......
......@@ -21,12 +21,6 @@ from Config import MODEL_PARAMS_DIR
IMAGENET_DIR = '/home/nz11/ILSVRC2012/'
NUM_TUNE_CLASSES = 200
IMAGES_PER_CLASS = 50
class VGG16(Benchmark):
def buildModel(self):
......@@ -104,19 +98,16 @@ class VGG16(Benchmark):
def data_preprocess(self):
X_val = np.fromfile(MODEL_PARAMS_DIR + '/vgg16_imagenet/test_input.bin', dtype=np.float32)
y_val = np.fromfile(MODEL_PARAMS_DIR + '/vgg16_imagenet/test_labels.bin', dtype=np.uint32)
X_val = X_val.reshape((-1, 3, 224, 224))
X_train, y_train = None, None
X_test = X_val[0:5000]
y_test = y_val[0:5000]
X_tuner = X_val[5000:]
y_tuner = y_val[5000:]
X_test = np.fromfile(MODEL_PARAMS_DIR + '/vgg16_imagenet/test_input.bin', dtype=np.float32)
X_test = X_test.reshape((-1, 3, 224, 224))
y_test = np.fromfile(MODEL_PARAMS_DIR + '/vgg16_imagenet/test_labels.bin', dtype=np.uint32)
X_tuner = np.fromfile(MODEL_PARAMS_DIR + '/vgg16_imagenet/tune_input.bin', dtype=np.float32)
X_tuner = X_tuner.reshape((-1, 3, 224, 224))
y_tuner = np.fromfile(MODEL_PARAMS_DIR + '/vgg16_imagenet/tune_labels.bin', dtype=np.uint32)
return X_train, y_train, X_test, y_test, X_tuner, y_tuner
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment