Skip to content
Snippets Groups Projects
vgg16.py 1.58 KiB
Newer Older
  • Learn to ignore specific revisions
  • from typing import Iterable
    
    from torch.nn import Linear, ReLU, Sequential
    
    from ._container import Classifier, make_conv_pool_activ
    
    
    class _VGG16(Classifier):
        """VGG16-style classifier: a fixed convolutional feature extractor
        followed by a configurable stack of fully-connected layers.

        Args:
            linear_inouts: Widths of the fully-connected stack. Each pair of
                consecutive widths becomes one ``Linear``; e.g. ``[512, 512, 10]``
                builds ``Linear(512, 512)`` and ``Linear(512, 10)``. A ``ReLU``
                is inserted between consecutive ``Linear`` layers, with no
                activation after the final one. Any iterable of ints is
                accepted (it is materialized once, so generators work too).
        """

        def __init__(self, linear_inouts: Iterable[int]):
            # Each make_conv_pool_activ(...) call returns a sequence of modules
            # (hence the star-unpacking). The extra positional argument ``2`` on
            # the last call of each stage is presumably a pooling size —
            # NOTE(review): confirm against ._container.make_conv_pool_activ.
            convs = Sequential(
                *make_conv_pool_activ(3, 64, 3, ReLU, padding=1),
                *make_conv_pool_activ(64, 64, 3, ReLU, 2, padding=1),
                *make_conv_pool_activ(64, 128, 3, ReLU, padding=1),
                *make_conv_pool_activ(128, 128, 3, ReLU, 2, padding=1),
                *make_conv_pool_activ(128, 256, 3, ReLU, padding=1),
                *make_conv_pool_activ(256, 256, 3, ReLU, padding=1),
                *make_conv_pool_activ(256, 256, 3, ReLU, 2, padding=1),
                *make_conv_pool_activ(256, 512, 3, ReLU, padding=1),
                *make_conv_pool_activ(512, 512, 3, ReLU, padding=1),
                *make_conv_pool_activ(512, 512, 3, ReLU, 2, padding=1),
                *make_conv_pool_activ(512, 512, 3, ReLU, padding=1),
                *make_conv_pool_activ(512, 512, 3, ReLU, padding=1),
                *make_conv_pool_activ(512, 512, 3, ReLU, 2, padding=1),
            )
            # Materialize once: the parameter is typed as Iterable, but slicing
            # (widths[1:]) requires a sequence and would fail on a generator.
            widths = list(linear_inouts)
            linear_layers = [
                Linear(in_, out) for in_, out in zip(widths, widths[1:])
            ]
            # Interleave as Linear, ReLU, Linear, ..., Linear — a ReLU goes
            # between consecutive Linear layers but not after the last one.
            # An empty or single-width input yields an empty Sequential.
            linears_and_relus = []
            for index, layer in enumerate(linear_layers):
                if index:
                    linears_and_relus.append(ReLU())
                linears_and_relus.append(layer)
            linears = Sequential(*linears_and_relus)
            super().__init__(convs, linears)
    
    
    class VGG16Cifar10(_VGG16):
        """VGG16 with a 512 -> 512 -> 10 fully-connected head.

        The final layer has 10 outputs, matching the CIFAR-10 class count.
        """

        def __init__(self):
            super().__init__(linear_inouts=[512, 512, 10])
    
    
    class VGG16Cifar100(_VGG16):
        """VGG16 with a 512 -> 512 -> 100 fully-connected head.

        The final layer has 100 outputs, matching the CIFAR-100 class count.
        """

        def __init__(self):
            super().__init__(linear_inouts=[512, 512, 100])