Skip to content
Snippets Groups Projects
Commit fff3bae3 authored by nz11's avatar nz11
Browse files

Update resnet50_imagenet.py

parent f3627042
No related branches found
No related tags found
No related merge requests found
import os import os
import glob import glob
import random
import scipy import scipy
import scipy.io import scipy.io
import cv2 import cv2
import numpy as np import numpy as np
import tensorflow as tf
import keras import keras
from keras.models import Sequential, Model from keras.models import Sequential, Model
from keras.layers import * from keras.layers import *
from keras.applications.vgg16 import VGG16, preprocess_input from keras.applications.resnet50 import ResNet50, preprocess_input
from keras.utils import to_categorical from keras.utils import to_categorical
from keras import backend as K from keras import backend as K
...@@ -26,7 +28,7 @@ K.set_image_data_format('channels_first') ...@@ -26,7 +28,7 @@ K.set_image_data_format('channels_first')
data_format = 'channels_first' data_format = 'channels_first'
IMAGENET_DIR = '/shared/hsharif3/ILSVRC2012/' IMAGENET_DIR = '/home/nz11/ILSVRC2012/'
OUTPUT_DIR = 'data/resnet50_imagenet/' OUTPUT_DIR = 'data/resnet50_imagenet/'
NUM_CLASSES = 100 NUM_CLASSES = 100
...@@ -108,9 +110,10 @@ def get_resnet50_nchw_keras(): ...@@ -108,9 +110,10 @@ def get_resnet50_nchw_keras():
x = ZeroPadding2D((3, 3))(img_input) x = ZeroPadding2D((3, 3))(img_input)
x = Conv2D(64, (7, 7), strides=(2, 2))(x) x = Conv2D(64, (7, 7), strides=(2, 2))(x)
x = BatchNormalization(axis=bn_axis)(x) # x = BatchNormalization(axis=bn_axis)(x)
x = Activation('relu')(x) x = Activation('relu')(x)
x = MaxPooling2D((3, 3), strides=(2, 2))(x) x = MaxPooling2D((3, 3), strides=(2, 2))(x)
x = BatchNormalization(axis=bn_axis)(x)
x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1)) x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
x = identity_block(x, 3, [64, 64, 256], stage=2, block='b') x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
...@@ -138,11 +141,24 @@ def get_resnet50_nchw_keras(): ...@@ -138,11 +141,24 @@ def get_resnet50_nchw_keras():
x = Activation('softmax')(x) x = Activation('softmax')(x)
model = Model(img_input, x) model = Model(img_input, x)
original_model = ResNet50()
for i in range(len(original_model.layers)):
try:
model.layers[i].set_weights(original_model.layers[i].get_weights())
# model.layers[i].trainable = False
except:
print (i, 'skipped')
model.layers[5].set_weights(original_model.layers[3].get_weights())
return model return model
def load_image(x): def load_image(x):
image = cv2.imread(x) image = cv2.imread(x)
height, width, _ = image.shape height, width, _ = image.shape
...@@ -210,11 +226,62 @@ X_test = np.array(X_test) ...@@ -210,11 +226,62 @@ X_test = np.array(X_test)
y_true = np.array(y_true) y_true = np.array(y_true)
def train_helper(x):
try:
x = x.decode('utf-8')
except:
pass
image = load_image(x)
y = np.zeros(1000, dtype=np.uint8)
y[synset_to_keras_idx[x.split('/')[-2]]]= 1
return image, y
train_images = glob.glob(IMAGENET_DIR + 'train/*/*')
random.shuffle(train_images)
dataset = tf.data.Dataset().from_tensor_slices(train_images)
dataset = dataset.map(
lambda x : tf.py_func(train_helper, [x], [tf.float32, tf.uint8]),
num_parallel_calls=16
)
dataset = dataset.shuffle(buffer_size=1000)
dataset = dataset.batch(32)
dataset = dataset.repeat()
next_element = dataset.make_one_shot_iterator().get_next()
sess = tf.Session()
def generate():
while True:
yield sess.run(next_element)
model.compile(optimizer=keras.optimizers.Adam(lr=0.00001), loss='categorical_crossentropy', metrics=['acc'])
model.fit_generator(generate(), steps_per_epoch=1000, validation_data=(X_test, to_categorical(y_true, num_classes=1000)), epochs=6)
translate_to_approxhpvm(model, OUTPUT_DIR, X_test[:VAL_SIZE], y_true[:VAL_SIZE], 1000) translate_to_approxhpvm(model, OUTPUT_DIR, X_test[:VAL_SIZE], y_true[:VAL_SIZE], 1000)
dumpCalibrationData(OUTPUT_DIR + 'test_input.bin', X_test, OUTPUT_DIR + 'test_labels.bin', y_true) dumpCalibrationData(OUTPUT_DIR + 'test_input.bin', X_test, OUTPUT_DIR + 'test_labels.bin', y_true)
# pred = np.argmax(model_nchw.predict(X_test), axis=1) # pred = np.argmax(model.predict(X_test), axis=1)
# print ('val accuracy', np.sum(pred == y_true.ravel()) / len(X_test)) # print ('val accuracy', np.sum(pred == y_true.ravel()) / len(X_test))
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment