Skip to content
Snippets Groups Projects
Commit 654b8ff8 authored by nz11's avatar nz11
Browse files

Update mobilenetv2_cifar10.py

parent d8208d32
No related branches found
No related tags found
No related merge requests found
......@@ -44,9 +44,9 @@ def _make_divisible(v, divisor, min_value=None):
return new_v
def _inverted_res_block(inputs, expansion, stride, alpha, filters, block_id):
channel_axis = 1 if K.image_data_format() == 'channels_first' else -1
channel_axis = 1
in_channels = inputs.shape[1] if K.image_data_format() == 'channels_first' else inputs.shape[-1]
in_channels = inputs.shape[1]
pointwise_conv_filters = int(filters * alpha)
pointwise_filters = _make_divisible(pointwise_conv_filters, 8)
x = inputs
......@@ -57,9 +57,9 @@ def _inverted_res_block(inputs, expansion, stride, alpha, filters, block_id):
x = Activation('relu')(x)
if stride == 2:
x = ZeroPadding2D(padding=((0, 1), (0, 1)))(x)
x = ZeroPadding2D(padding=(1, 1))(x)
else:
x = ZeroPadding2D(padding=((1, 1), (1, 1)))(x)
x = ZeroPadding2D(padding=(1, 1))(x)
x = DepthwiseConv2D(kernel_size=3, strides=stride, use_bias=False, padding='valid')(x)
x = BatchNormalization(axis=channel_axis)(x)
......@@ -75,12 +75,12 @@ def _inverted_res_block(inputs, expansion, stride, alpha, filters, block_id):
def get_mobilenetv2(alpha=1.0, depth_multiplier=1):
channel_axis = 1 if K.image_data_format() == 'channels_first' else -1
channel_axis = 1
first_block_filters = _make_divisible(32 * alpha, 8)
img_input = Input(shape=(3, 32, 32))
x = ZeroPadding2D(padding=((0, 2), (0, 2)))(img_input)
x = ZeroPadding2D(padding=(1, 1))(img_input)
x = Conv2D(first_block_filters, kernel_size=3, strides=1, padding='valid', use_bias=False)(x)
#x = BatchNormalization(axis=channel_axis)(x)
#x = Activation('relu')(x)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment