Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from typing import Iterable
from torch.nn import Linear, ReLU, Sequential
from ._container import Classifier, make_conv_pool_activ
class _VGG16(Classifier):
    """VGG16-style classifier: a fixed convolutional stack plus a linear head.

    The convolutional part is assembled from 13 ``make_conv_pool_activ``
    calls whose channel widths go 3 -> 64 -> 128 -> 256 -> 512. The linear
    head is configurable through ``linear_inouts``.
    """

    def __init__(self, linear_inouts: Iterable[int]) -> None:
        """Build the convolutional stack and the fully connected head.

        Args:
            linear_inouts: Feature sizes of the linear head, listed from its
                input to its output. Each consecutive pair ``(a, b)`` becomes
                one ``Linear(a, b)``; a ``ReLU`` is placed between adjacent
                ``Linear`` layers but not after the last one. Any iterable
                is accepted, including one-shot iterators. With fewer than
                two sizes the head is an empty ``Sequential``.
        """
        # NOTE(review): the optional fifth positional argument (2) is
        # presumably a pooling size and ``padding=1`` a conv keyword --
        # confirm against ``make_conv_pool_activ`` in ``._container``.
        convs = Sequential(
            *make_conv_pool_activ(3, 64, 3, ReLU, padding=1),
            *make_conv_pool_activ(64, 64, 3, ReLU, 2, padding=1),
            *make_conv_pool_activ(64, 128, 3, ReLU, padding=1),
            *make_conv_pool_activ(128, 128, 3, ReLU, 2, padding=1),
            *make_conv_pool_activ(128, 256, 3, ReLU, padding=1),
            *make_conv_pool_activ(256, 256, 3, ReLU, padding=1),
            *make_conv_pool_activ(256, 256, 3, ReLU, 2, padding=1),
            *make_conv_pool_activ(256, 512, 3, ReLU, padding=1),
            *make_conv_pool_activ(512, 512, 3, ReLU, padding=1),
            *make_conv_pool_activ(512, 512, 3, ReLU, 2, padding=1),
            *make_conv_pool_activ(512, 512, 3, ReLU, padding=1),
            *make_conv_pool_activ(512, 512, 3, ReLU, padding=1),
            *make_conv_pool_activ(512, 512, 3, ReLU, 2, padding=1),
        )

        # Materialize once: the parameter is typed as Iterable, so it may be
        # a generator that supports neither slicing nor a second pass.
        sizes = list(linear_inouts)
        linear_layers = [
            Linear(in_, out) for in_, out in zip(sizes, sizes[1:])
        ]

        # Interleave a ReLU between consecutive Linear layers; the final
        # Linear is left without an activation.
        head = []
        for layer in linear_layers:
            if head:
                head.append(ReLU())
            head.append(layer)
        linears = Sequential(*head)

        super().__init__(convs, linears)
class VGG16Cifar10(_VGG16):
    """``_VGG16`` variant whose linear head has sizes 512 -> 512 -> 10."""

    def __init__(self) -> None:
        """Construct the network with a 10-output linear head."""
        # Head feature sizes, input first; the last entry is the output width.
        head_sizes = [512, 512, 10]
        super().__init__(head_sizes)
class VGG16Cifar100(_VGG16):
    """``_VGG16`` variant whose linear head has sizes 512 -> 512 -> 100."""

    def __init__(self) -> None:
        """Construct the network with a 100-output linear head."""
        # Head feature sizes, input first; the last entry is the output width.
        head_sizes = [512, 512, 100]
        super().__init__(head_sizes)