import os
import glob
import random
import scipy
import scipy.io
import cv2
import numpy as np
import tensorflow as tf
import keras
from keras.models import Sequential, Model
from keras.layers import *
from keras.utils import to_categorical
from keras import backend as K
import torchvision.models as models
from frontend.approxhpvm_translator import translate_to_approxhpvm
from frontend.weight_utils import dumpCalibrationData2
np.random.seed(2020)
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
K.set_image_data_format('channels_first')
data_format = 'channels_first'
IMAGENET_DIR = '/home/nz11/ILSVRC2012/'
OUTPUT_DIR = 'data/alexnet_imagenet_tune/'
WEIGHTS_PATH = 'data/weights.h5'
NUM_CLASSES = 200
IMAGES_PER_CLASS = 40
# VAL_SIZE = 100
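
# Assumed directory layout under IMAGENET_DIR (inferred from the glob patterns
# used below, not stated explicitly anywhere in the script):
#   train/<synset>/<image>.JPEG
#   val/<synset>/<image>.JPEG
#   ILSVRC2012_devkit_t12/data/meta.mat and synset_words.txt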
def get_alexnet_nchw_keras():
    input_layer = Input((3, 224, 224))

    x = ZeroPadding2D((2, 2))(input_layer)
    x = Conv2D(64, (11, 11), strides=4, padding='valid')(x)
    x = Activation('relu')(x)
    x = MaxPooling2D(3, 2)(x)

    x = ZeroPadding2D((2, 2))(x)
    x = Conv2D(192, (5, 5), padding='valid')(x)
    x = Activation('relu')(x)
    x = MaxPooling2D(3, 2)(x)

    x = Conv2D(384, (3, 3), padding='same')(x)
    x = Activation('relu')(x)
    x = Conv2D(256, (3, 3), padding='same')(x)
    x = Activation('relu')(x)
    x = Conv2D(256, (3, 3), padding='same')(x)
    x = Activation('relu')(x)
    x = MaxPooling2D(3, 2)(x)

    x = Flatten()(x)
    x = Dropout(0.5)(x)
    x = Dense(4096)(x)
    x = Activation('relu')(x)
    x = Dropout(0.5)(x)
    x = Dense(4096)(x)
    x = Activation('relu')(x)
    x = Dense(1000)(x)
    x = Activation('softmax')(x)

    model_nchw = Model(input_layer, x)

    # Copy pretrained torchvision AlexNet weights into the Keras graph.
    # PyTorch stores conv kernels as (out, in, kH, kW) while Keras expects
    # (kH, kW, in, out); dense weights go from (out, in) to (in, out).
    torch_model = models.alexnet(pretrained=True)
    torch_weights = list(torch_model.parameters())
    j = 0
    for i in range(len(model_nchw.layers)):
        if 2 * j >= len(torch_weights):
            break
        w = torch_weights[2 * j].detach().numpy()
        b = torch_weights[2 * j + 1].detach().numpy()
        if len(w.shape) == 4:
            w = np.transpose(w, (2, 3, 1, 0))
        else:
            w = w.transpose()
        try:
            model_nchw.layers[i].set_weights([w, b])
            j += 1
            print([w.shape, b.shape], 'loaded')
        except ValueError:
            # Layers without weights (padding, pooling, activations, dropout,
            # flatten) reject set_weights; skip them and try the next layer.
            pass

    return model_nchw
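
# Optional sanity check, not part of the original script: verify that the ported
# Keras model reproduces the reference torchvision AlexNet on random NCHW input.
# The Keras graph ends in a softmax while the PyTorch model emits raw logits, so
# a softmax is applied to the PyTorch output before comparing. The function name
# and tolerance are illustrative choices.
def check_port_against_torch(keras_model, atol=1e-4):
    import torch
    ref = models.alexnet(pretrained=True).eval()
    x = np.random.rand(1, 3, 224, 224).astype(np.float32)
    with torch.no_grad():
        logits = ref(torch.from_numpy(x)).numpy()
    logits = logits - logits.max(axis=1, keepdims=True)  # stabilize the softmax
    torch_probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
    keras_probs = keras_model.predict(x)
    return np.allclose(torch_probs, keras_probs, atol=atol)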
def load_image(x):
    image = cv2.imread(x)
    height, width, _ = image.shape

    # Resize so the shorter side is 256 pixels, then center-crop to 224x224.
    new_height = height * 256 // min(image.shape[:2])
    new_width = width * 256 // min(image.shape[:2])
    image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_CUBIC)

    height, width, _ = image.shape
    startx = width // 2 - (224 // 2)
    starty = height // 2 - (224 // 2)
    image = image[starty:starty + 224, startx:startx + 224]

    # OpenCV decodes to BGR uint8; convert to RGB, scale to [0, 1], and move to
    # channels-first (CHW) to match the NCHW model.
    image = image[:, :, ::-1]
    image = image.astype(np.float32) / 255.0
    image = np.transpose(image, (2, 0, 1))

    # Standard torchvision ImageNet normalization, applied per channel.
    image[0] = (image[0] - 0.485) / 0.229
    image[1] = (image[1] - 0.456) / 0.224
    image[2] = (image[2] - 0.406) / 0.225

    return image
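
# Example usage of load_image (the file name below is hypothetical and only
# illustrates the expected val/<synset>/<image>.JPEG layout):
#   img = load_image(IMAGENET_DIR + 'val/n01440764/ILSVRC2012_val_00000293.JPEG')
#   assert img.shape == (3, 224, 224) and img.dtype == np.float32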
# Map the devkit's ILSVRC2012 class IDs to WordNet synsets and human-readable names.
meta = scipy.io.loadmat(IMAGENET_DIR + 'ILSVRC2012_devkit_t12/data/meta.mat')
original_idx_to_synset = {}
synset_to_name = {}
for i in range(1000):
    ilsvrc2012_id = int(meta['synsets'][i, 0][0][0][0])
    synset = meta['synsets'][i, 0][1][0]
    name = meta['synsets'][i, 0][2][0]
    original_idx_to_synset[ilsvrc2012_id] = synset
    synset_to_name[synset] = name

# Map synsets to the Keras/ImageNet class ordering defined by synset_words.txt.
synset_to_keras_idx = {}
keras_idx_to_name = {}
f = open(IMAGENET_DIR + 'ILSVRC2012_devkit_t12/data/synset_words.txt', 'r')
c = 0
for line in f:
    parts = line.strip().split(' ')
    synset_to_keras_idx[parts[0]] = c
    keras_idx_to_name[c] = ' '.join(parts[1:])
    c += 1
f.close()
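
# Illustration (assuming the standard synset_words.txt ordering): validation
# subdirectories are named by WordNet synset ID, and those IDs index into the
# Keras class ordering built above, e.g.
#   synset_to_keras_idx['n01440764']   # -> 0
#   keras_idx_to_name[0]               # -> 'tench, Tinca tinca'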
model = get_alexnet_nchw_keras()

# Sample NUM_CLASSES validation classes and at most IMAGES_PER_CLASS images per
# class, splitting them evenly into a test half and a tuning half.
X_tune, X_test = [], []
y_tune, y_true = [], []

classes = glob.glob(IMAGENET_DIR + 'val/*')
for c in np.random.permutation(len(classes))[:NUM_CLASSES]:
    x = np.array(glob.glob(classes[c] + '/*'))
    idx = np.random.permutation(len(x))
    idx = idx[:min(len(idx), IMAGES_PER_CLASS)]

    synset = classes[c].split('/')[-1]
    images = list(map(load_image, x[idx]))
    labels = [synset_to_keras_idx[synset]] * len(idx)

    X_test += images[:IMAGES_PER_CLASS // 2]
    y_true += labels[:IMAGES_PER_CLASS // 2]
    X_tune += images[IMAGES_PER_CLASS // 2:]
    y_tune += labels[IMAGES_PER_CLASS // 2:]

X_test = np.array(X_test)
y_true = np.array(y_true)
X_tune = np.array(X_tune)
y_tune = np.array(y_tune)
def train_helper(x):
    # tf.py_func passes the path in as bytes; decode to str if needed.
    if isinstance(x, bytes):
        x = x.decode('utf-8')
    image = load_image(x)
    # One-hot label derived from the synset directory name in the path.
    y = np.zeros(1000, dtype=np.uint8)
    y[synset_to_keras_idx[x.split('/')[-2]]] = 1
    return image, y
train_images = glob.glob(IMAGENET_DIR + 'train/*/*')
random.shuffle(train_images)

# TF 1.x input pipeline: wrap the Python loader in py_func and feed Keras
# through a Session-backed generator.
dataset = tf.data.Dataset.from_tensor_slices(train_images)
dataset = dataset.map(
    lambda x: tf.py_func(train_helper, [x], [tf.float32, tf.uint8]),
    num_parallel_calls=16
)
dataset = dataset.shuffle(buffer_size=1000)
dataset = dataset.batch(64)
dataset = dataset.repeat()
next_element = dataset.make_one_shot_iterator().get_next()

sess = tf.Session()

def generate():
    while True:
        yield sess.run(next_element)
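
# Optional shape check, not in the original script: pull one batch through the
# Session-backed pipeline before training to confirm the py_func loader yields
# (64, 3, 224, 224) float32 images and (64, 1000) one-hot uint8 labels.
#   xb, yb = next(generate())
#   print(xb.shape, xb.dtype, yb.shape, yb.dtype)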
model.compile(optimizer=keras.optimizers.Adam(lr=0.00001), loss='categorical_crossentropy', metrics=['acc'])

if os.path.exists(WEIGHTS_PATH):
    model.load_weights(WEIGHTS_PATH)
else:
    # Fine-tune in two stages, dropping the learning rate 10x for the second.
    model.fit_generator(generate(), steps_per_epoch=1000, validation_data=(X_test, to_categorical(y_true, num_classes=1000)), epochs=2)
    K.set_value(model.optimizer.lr, 0.000001)
    model.fit_generator(generate(), steps_per_epoch=1000, validation_data=(X_test, to_categorical(y_true, num_classes=1000)), epochs=6)
    model.save_weights(WEIGHTS_PATH)
translate_to_approxhpvm(model, OUTPUT_DIR, X_tune, y_tune, 1000)

# dumpCalibrationData2(OUTPUT_DIR + 'test_input_10K.bin', X_test, OUTPUT_DIR + 'test_labels_10K.bin', y_true)
dumpCalibrationData2(OUTPUT_DIR + 'tune_input.bin', X_tune, OUTPUT_DIR + 'tune_labels.bin', y_tune)
dumpCalibrationData2(OUTPUT_DIR + 'test_input.bin', X_test, OUTPUT_DIR + 'test_labels.bin', y_true)

pred = np.argmax(model.predict(X_test), axis=1)
print('test accuracy', np.sum(pred == y_true.ravel()) / len(X_test))

pred = np.argmax(model.predict(X_tune), axis=1)
print('tune accuracy', np.sum(pred == y_tune.ravel()) / len(X_tune))
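
# Optional, not in the original script: top-5 accuracy on the held-out test
# split, using the same model and arrays as above.
probs = model.predict(X_test)
top5 = np.argsort(probs, axis=1)[:, -5:]
print('test top-5 accuracy', np.mean([y in row for y, row in zip(y_true.ravel(), top5)]))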