Skip to content
Snippets Groups Projects
Commit 6247a3da authored by nz11's avatar nz11
Browse files

Add new file

parent 7ea016f0
No related branches found
No related tags found
No related merge requests found
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from keras.models import Sequential
from keras.layers import *
from keras.datasets import cifar10
from keras.utils import to_categorical
from keras.callbacks import *
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Model
from keras import optimizers
import keras.backend as K
(X_train, y_train), (X_test, y_test) = cifar10.load_data()
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
mean = np.mean(X_train, axis=(0, 1, 2), keepdims=True)
std = np.std(X_train, axis=(0, 1, 2), keepdims=True)
X_train = (X_train - mean) / (std + 1e-9)
X_test = (X_test - mean) / (std + 1e-9)
y_train = to_categorical(y_train, num_classes=10)
y_test = to_categorical(y_test, num_classes=10)
def get_mobilenet(alpha=1, depth_multiplier=1):
model = Sequential()
def _conv_block(filters, alpha, kernel=(3, 3), strides=(1, 1)):
channel_axis = 1 if K.image_data_format() == 'channels_first' else -1
filters = int(filters * alpha)
model.add(Conv2D(filters, kernel,
padding='same',
use_bias=False,
strides=strides,
input_shape=(32, 32, 3)))
model.add(BatchNormalization(axis=channel_axis))
model.add(ReLU())
def _depthwise_conv_block(pointwise_conv_filters, alpha, depth_multiplier=1, strides=(1, 1)):
channel_axis = 1 if K.image_data_format() == 'channels_first' else -1
pointwise_conv_filters = int(pointwise_conv_filters * alpha)
model.add(DepthwiseConv2D((3, 3),
padding='same',
depth_multiplier=depth_multiplier,
strides=strides,
use_bias=False))
model.add(BatchNormalization(axis=channel_axis))
model.add(ReLU())
model.add(Conv2D(pointwise_conv_filters, (1, 1),
padding='same',
use_bias=False,
strides=(1, 1)))
model.add(BatchNormalization(axis=channel_axis))
model.add(ReLU())
_conv_block(32, alpha, strides=(1, 1))
_depthwise_conv_block(64, alpha, depth_multiplier)
_depthwise_conv_block(128, alpha, depth_multiplier,
strides=(2, 2))
_depthwise_conv_block(128, alpha, depth_multiplier)
model.add(Dropout(rate=0.5))
_depthwise_conv_block(256, alpha, depth_multiplier,
strides=(2, 2))
_depthwise_conv_block(256, alpha, depth_multiplier)
model.add(Dropout(rate=0.5))
_depthwise_conv_block(512, alpha, depth_multiplier,
strides=(2, 2))
_depthwise_conv_block(512, alpha, depth_multiplier)
_depthwise_conv_block(512, alpha, depth_multiplier)
model.add(Dropout(rate=0.5))
_depthwise_conv_block(512, alpha, depth_multiplier)
_depthwise_conv_block(512, alpha, depth_multiplier)
_depthwise_conv_block(512, alpha, depth_multiplier)
model.add(Dropout(rate=0.5))
_depthwise_conv_block(1024, alpha, depth_multiplier,
strides=(2, 2))
_depthwise_conv_block(1024, alpha, depth_multiplier)
model.add(Dropout(rate=0.5))
model.add(GlobalAveragePooling2D())
model.add(Dense(10, activation='softmax'))
return model
# data augmentation, horizontal flips only
datagen = ImageDataGenerator(
featurewise_center=False,
featurewise_std_normalization=False,
rotation_range=0.0,
width_shift_range=0.0,
height_shift_range=0.0,
vertical_flip=False,
horizontal_flip=True)
datagen.fit(X_train)
model = get_mobilenet()
learning_rates=[]
for i in range(5):
learning_rates.append(2e-2)
for i in range(50-5):
learning_rates.append(1e-2)
for i in range(100-50):
learning_rates.append(8e-3)
for i in range(150-100):
learning_rates.append(4e-3)
for i in range(200-150):
learning_rates.append(2e-3)
for i in range(300-200):
learning_rates.append(1e-3)
callbacks = [
LearningRateScheduler(lambda epoch: float(learning_rates[epoch]))
]
model.compile(optimizer=optimizers.SGD(lr=learning_rates[0], momentum=0.9, decay=0.0, nesterov=False),
loss='categorical_crossentropy',
metrics=['accuracy'])
model.fit_generator(
datagen.flow(X_train, y_train, batch_size=128),
steps_per_epoch=int(np.ceil(50000 / 128)),
validation_data=(X_test, y_test),
epochs=300,
callbacks=callbacks
)
model.save_weights('mobilenet_cifar10.h5')
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment