Skip to content
Snippets Groups Projects
mobilenet.py 1.61 KiB
Newer Older
  • Learn to ignore specific revisions
  • from torch.nn import AvgPool2d, BatchNorm2d, Conv2d, Linear, ReLU, Sequential
    
    from ._container import Classifier, make_conv_pool_activ
    
    
    def _make_seq(in_channels, out_channels, c_kernel_size, gc_stride, gc_kernel_size=3):
        """Assemble one conv stage followed by a per-channel grouped conv stage.

        The returned ``Sequential`` holds, in order:

        1. whatever ``make_conv_pool_activ`` yields for an
           ``in_channels -> out_channels`` convolution of kernel size
           ``c_kernel_size`` (bias disabled),
        2. ``BatchNorm2d`` + ``ReLU``,
        3. a ``Conv2d`` over ``out_channels`` channels with
           ``groups=out_channels`` (one filter group per channel), kernel size
           ``gc_kernel_size`` and stride ``gc_stride`` (bias disabled),
        4. ``BatchNorm2d`` + ``ReLU``.

        Both convolutions are padded with ``(kernel_size - 1) // 2``.

        NOTE(review): the result of ``make_conv_pool_activ`` is star-unpacked,
        so it is presumably an iterable of modules -- confirm in ``_container``.
        """
        first_padding = (c_kernel_size - 1) // 2
        grouped_padding = (gc_kernel_size - 1) // 2

        # Modules are created in the same order they appear in the Sequential
        # so that seeded parameter initialisation stays reproducible.
        leading = make_conv_pool_activ(
            in_channels,
            out_channels,
            c_kernel_size,
            bias=False,
            padding=first_padding,
        )
        layers = [*leading, BatchNorm2d(out_channels, eps=0.001), ReLU()]

        grouped_conv = Conv2d(
            out_channels,
            out_channels,
            gc_kernel_size,
            bias=False,
            stride=gc_stride,
            padding=grouped_padding,
            groups=out_channels,
        )
        layers += [grouped_conv, BatchNorm2d(out_channels, eps=0.001), ReLU()]

        return Sequential(*layers)
    
    
    class MobileNet(Classifier):
        """MobileNet-style network expressed as a ``Classifier``.

        The convolutional part is thirteen ``_make_seq`` blocks, followed by a
        final 1024 -> 1024 kernel-size-1 stage from ``make_conv_pool_activ``,
        ``BatchNorm2d``, ``ReLU`` and ``AvgPool2d(2)``. The linear part is a
        single ``Linear(1024, 10)``.

        NOTE(review): four stride-2 blocks plus ``AvgPool2d(2)`` feeding a
        1024-feature ``Linear`` suggests 32x32 inputs and 10 classes -- confirm
        against the dataset and the ``Classifier`` base class.
        """

        def __init__(self):
            # One row per _make_seq block:
            # (in_channels, out_channels, c_kernel_size, gc_stride)
            block_specs = [
                (3, 32, 3, 1),
                (32, 64, 1, 2),
                (64, 128, 1, 1),
                (128, 128, 1, 2),
                (128, 256, 1, 1),
                (256, 256, 1, 2),
                (256, 512, 1, 1),
                # Four identical 512 -> 512 blocks with stride 1.
                *([(512, 512, 1, 1)] * 4),
                (512, 512, 1, 2),
                (512, 1024, 1, 1),
            ]
            blocks = [_make_seq(*spec) for spec in block_specs]

            # Closing 1x1 stage; built after the blocks to keep creation order.
            closing = make_conv_pool_activ(1024, 1024, 1, padding=0, bias=False)

            convs = Sequential(
                *blocks,
                *closing,
                BatchNorm2d(1024, eps=0.001),
                ReLU(),
                AvgPool2d(2),
            )
            linears = Sequential(Linear(1024, 10))
            super().__init__(convs, linears)