Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from torch.nn import AvgPool2d, BatchNorm2d, Conv2d, Linear, ReLU, Sequential
from ._container import Classifier, make_conv_pool_activ
def _make_seq(in_channels, out_channels, c_kernel_size, gc_stride, gc_kernel_size=3):
    """Build one convolutional stage as a ``Sequential``.

    The stage consists of, in order:

    1. the layers produced by ``make_conv_pool_activ`` (from ``._container``;
       presumably a convolution mapping ``in_channels`` to ``out_channels`` --
       TODO confirm against that helper), created without bias;
    2. ``BatchNorm2d`` + ``ReLU``;
    3. a grouped ``Conv2d`` with ``groups=out_channels`` (one filter group
       per channel, i.e. a depthwise convolution), created without bias;
    4. ``BatchNorm2d`` + ``ReLU``.

    Args:
        in_channels: Channel count passed to ``make_conv_pool_activ`` as input.
        out_channels: Channel count produced by the first step and kept by
            the grouped convolution.
        c_kernel_size: Kernel size handed to ``make_conv_pool_activ``.
        gc_stride: Stride of the grouped convolution.
        gc_kernel_size: Kernel size of the grouped convolution.

    Returns:
        A ``torch.nn.Sequential`` holding the layers listed above.
    """
    bn_eps = 0.001

    # Both convolutions use (kernel - 1) // 2 padding, which is "same"-style
    # padding for odd kernel sizes at stride 1.
    first_padding = (c_kernel_size - 1) // 2
    grouped_padding = (gc_kernel_size - 1) // 2

    # Modules are instantiated in the exact order they end up in the
    # Sequential, so weight initialisation draws from the RNG in that order.
    layers = list(
        make_conv_pool_activ(
            in_channels,
            out_channels,
            c_kernel_size,
            bias=False,
            padding=first_padding,
        )
    )
    layers.append(BatchNorm2d(out_channels, eps=bn_eps))
    layers.append(ReLU())
    layers.append(
        Conv2d(
            out_channels,
            out_channels,
            gc_kernel_size,
            stride=gc_stride,
            padding=grouped_padding,
            groups=out_channels,
            bias=False,
        )
    )
    layers.append(BatchNorm2d(out_channels, eps=bn_eps))
    layers.append(ReLU())

    return Sequential(*layers)
class MobileNet(Classifier):
    """MobileNet-style classifier assembled from ``_make_seq`` stages.

    The convolutional part is thirteen ``_make_seq`` stages (channels growing
    from 3 to 1024, four of them with a stride-2 grouped convolution),
    followed by a 1x1 ``make_conv_pool_activ`` step on 1024 channels,
    ``BatchNorm2d``, ``ReLU`` and ``AvgPool2d(2)``.  The linear part is a
    single ``Linear(1024, 10)``.

    NOTE(review): ``Linear(1024, 10)`` implies the convolutional part must
    yield 1024 features per sample once flattened -- presumably ``Classifier``
    does the flattening and inputs are small (e.g. 32x32) images; confirm
    against ``._container``.
    """

    def __init__(self):
        # One row per stage: (in_channels, out_channels, c_kernel_size, gc_stride).
        stage_specs = (
            (3, 32, 3, 1),
            (32, 64, 1, 2),
            (64, 128, 1, 1),
            (128, 128, 1, 2),
            (128, 256, 1, 1),
            (256, 256, 1, 2),
            (256, 512, 1, 1),
            (512, 512, 1, 1),
            (512, 512, 1, 1),
            (512, 512, 1, 1),
            (512, 512, 1, 1),
            (512, 512, 1, 2),
            (512, 1024, 1, 1),
        )

        # Stages are built first, then the tail, then the linear head --
        # the same construction order as the layers' final positions.
        stages = [_make_seq(*spec) for spec in stage_specs]

        tail = [
            *make_conv_pool_activ(1024, 1024, 1, padding=0, bias=False),
            BatchNorm2d(1024, eps=0.001),
            ReLU(),
            AvgPool2d(2),
        ]

        feature_layers = Sequential(*stages, *tail)
        output_layers = Sequential(Linear(1024, 10))

        super().__init__(feature_layers, output_layers)