Skip to content
Snippets Groups Projects
Commit e0a9b1b3 authored by nz11's avatar nz11
Browse files

Update mobilenet_shallow.py

parent cbfa3cab
No related branches found
No related tags found
No related merge requests found
......@@ -41,18 +41,18 @@ def get_mobilenet(alpha=1, depth_multiplier=1):
model = Sequential()
def _conv_block(filters, alpha, kernel=(3, 3), strides=(1, 1)):
channel_axis = 1 if K.image_data_format() == 'channels_first' else -1
channel_axis = 1
filters = int(filters * alpha)
model.add(Conv2D(filters, kernel,
padding='same',
use_bias=False,
strides=strides,
input_shape=(32, 32, 3)))
input_shape=(3, 32, 32)))
model.add(BatchNormalization(axis=channel_axis))
model.add(ReLU())
model.add(Activation('relu'))
def _depthwise_conv_block(pointwise_conv_filters, alpha, depth_multiplier=1, strides=(1, 1)):
channel_axis = 1 if K.image_data_format() == 'channels_first' else -1
channel_axis = 1
pointwise_conv_filters = int(pointwise_conv_filters * alpha)
model.add(DepthwiseConv2D((3, 3),
......@@ -61,13 +61,13 @@ def get_mobilenet(alpha=1, depth_multiplier=1):
strides=strides,
use_bias=False))
model.add(BatchNormalization(axis=channel_axis))
model.add(ReLU())
model.add(Activation('relu'))
model.add(Conv2D(pointwise_conv_filters, (1, 1),
padding='same',
use_bias=False,
strides=(1, 1)))
model.add(BatchNormalization(axis=channel_axis))
model.add(ReLU())
model.add(Activation('relu'))
_conv_block(32, alpha, strides=(1, 1))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment