Skip to content
Snippets Groups Projects
Commit 60cd77c9 authored by nz11's avatar nz11
Browse files

changed to channels first

parent 9f5ff44f
No related branches found
No related tags found
No related merge requests found
...@@ -11,6 +11,8 @@ from keras.models import Model ...@@ -11,6 +11,8 @@ from keras.models import Model
from keras import optimizers from keras import optimizers
import keras.backend as K import keras.backend as K
K.set_image_dim_ordering('th')
(X_train, y_train), (X_test, y_test) = cifar10.load_data() (X_train, y_train), (X_test, y_test) = cifar10.load_data()
...@@ -22,6 +24,9 @@ std = np.std(X_train, axis=(0, 1, 2), keepdims=True) ...@@ -22,6 +24,9 @@ std = np.std(X_train, axis=(0, 1, 2), keepdims=True)
X_train = (X_train - mean) / (std + 1e-9) X_train = (X_train - mean) / (std + 1e-9)
X_test = (X_test - mean) / (std + 1e-9) X_test = (X_test - mean) / (std + 1e-9)
np.rollaxis(X_train, 3, 1)
np.rollaxis(X_test, 3, 1)
y_train = to_categorical(y_train, num_classes=10) y_train = to_categorical(y_train, num_classes=10)
y_test = to_categorical(y_test, num_classes=10) y_test = to_categorical(y_test, num_classes=10)
...@@ -36,9 +41,9 @@ def get_mobilenet(alpha=1, depth_multiplier=1): ...@@ -36,9 +41,9 @@ def get_mobilenet(alpha=1, depth_multiplier=1):
padding='same', padding='same',
use_bias=False, use_bias=False,
strides=strides, strides=strides,
input_shape=(32, 32, 3))) input_shape=(3, 32, 32)))
model.add(BatchNormalization(axis=channel_axis)) model.add(BatchNormalization(axis=channel_axis))
model.add(ReLU()) model.add(Activation('relu'))
def _depthwise_conv_block(pointwise_conv_filters, alpha, depth_multiplier=1, strides=(1, 1)): def _depthwise_conv_block(pointwise_conv_filters, alpha, depth_multiplier=1, strides=(1, 1)):
channel_axis = 1 if K.image_data_format() == 'channels_first' else -1 channel_axis = 1 if K.image_data_format() == 'channels_first' else -1
...@@ -50,13 +55,13 @@ def get_mobilenet(alpha=1, depth_multiplier=1): ...@@ -50,13 +55,13 @@ def get_mobilenet(alpha=1, depth_multiplier=1):
strides=strides, strides=strides,
use_bias=False)) use_bias=False))
model.add(BatchNormalization(axis=channel_axis)) model.add(BatchNormalization(axis=channel_axis))
model.add(ReLU()) model.add(Activation('relu'))
model.add(Conv2D(pointwise_conv_filters, (1, 1), model.add(Conv2D(pointwise_conv_filters, (1, 1),
padding='same', padding='same',
use_bias=False, use_bias=False,
strides=(1, 1))) strides=(1, 1)))
model.add(BatchNormalization(axis=channel_axis)) model.add(BatchNormalization(axis=channel_axis))
model.add(ReLU()) model.add(Activation('relu'))
_conv_block(32, alpha, strides=(1, 1)) _conv_block(32, alpha, strides=(1, 1))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment