from typing import Iterable from torch.nn import Linear, ReLU, Sequential from ._container import Classifier, make_conv_pool_activ class _VGG16(Classifier): def __init__(self, linear_inouts: Iterable[int]): convs = Sequential( *make_conv_pool_activ(3, 64, 3, ReLU, padding=1), *make_conv_pool_activ(64, 64, 3, ReLU, 2, padding=1), *make_conv_pool_activ(64, 128, 3, ReLU, padding=1), *make_conv_pool_activ(128, 128, 3, ReLU, 2, padding=1), *make_conv_pool_activ(128, 256, 3, ReLU, padding=1), *make_conv_pool_activ(256, 256, 3, ReLU, padding=1), *make_conv_pool_activ(256, 256, 3, ReLU, 2, padding=1), *make_conv_pool_activ(256, 512, 3, ReLU, padding=1), *make_conv_pool_activ(512, 512, 3, ReLU, padding=1), *make_conv_pool_activ(512, 512, 3, ReLU, 2, padding=1), *make_conv_pool_activ(512, 512, 3, ReLU, padding=1), *make_conv_pool_activ(512, 512, 3, ReLU, padding=1), *make_conv_pool_activ(512, 512, 3, ReLU, 2, padding=1) ) linear_layers = [ Linear(in_, out) for in_, out in zip(linear_inouts, linear_inouts[1:]) ] linear_relus = [ReLU() for _ in range(2 * len(linear_layers) - 1)] linear_relus[::2] = linear_layers linears = Sequential(*linear_relus) super().__init__(convs, linears) class VGG16Cifar10(_VGG16): def __init__(self): super().__init__([512, 512, 10]) class VGG16Cifar100(_VGG16): def __init__(self): super().__init__([512, 512, 100])