Skip to content
Snippets Groups Projects
alexnet.py 1.97 KiB
Newer Older
  • Learn to ignore specific revisions
  • from torch.nn import Linear, ReLU, Sequential, Tanh
    
    from ._container import Classifier, make_conv_pool_activ
    
    
    class AlexNet(Classifier):
        def __init__(self):
            convs = Sequential(
                *make_conv_pool_activ(3, 64, 11, Tanh, pool_size=2, padding=5),
                *make_conv_pool_activ(64, 192, 5, Tanh, pool_size=2, padding=2),
                *make_conv_pool_activ(192, 384, 3, Tanh, padding=1),
                *make_conv_pool_activ(384, 256, 3, Tanh, padding=1),
                *make_conv_pool_activ(256, 256, 3, Tanh, pool_size=2, padding=1)
            )
            linears = Sequential(Linear(4096, 10))
            super().__init__(convs, linears)
    
    
    class AlexNet2(Classifier):
        def __init__(self):
            convs = Sequential(
                *make_conv_pool_activ(3, 32, 3, Tanh, padding=1),
                *make_conv_pool_activ(32, 32, 3, Tanh, pool_size=2, padding=1),
                *make_conv_pool_activ(32, 64, 3, Tanh, padding=1),
                *make_conv_pool_activ(64, 64, 3, Tanh, pool_size=2, padding=1),
                *make_conv_pool_activ(64, 128, 3, Tanh, padding=1),
                *make_conv_pool_activ(128, 128, 3, Tanh, pool_size=2, padding=1)
            )
            linears = Sequential(Linear(2048, 10))
            super().__init__(convs, linears)
    
    
    class AlexNetImageNet(Classifier):
        def __init__(self):
            convs = Sequential(
                *make_conv_pool_activ(
                    3, 64, 11, ReLU, padding=2, stride=4, pool_size=3, pool_stride=2
                ),
                *make_conv_pool_activ(
                    64, 192, 5, ReLU, padding=2, pool_size=3, pool_stride=2
                ),
                *make_conv_pool_activ(192, 384, 3, ReLU, padding=1),
                *make_conv_pool_activ(384, 256, 3, ReLU, padding=1),
                *make_conv_pool_activ(
                    256, 256, 3, ReLU, padding=1, pool_size=3, pool_stride=2
                )
            )
            linears = Sequential(
                Linear(9216, 4096), ReLU(), Linear(4096, 4096), ReLU(), Linear(4096, 1000),
            )
            super().__init__(convs, linears)