from torch.nn import AvgPool2d, BatchNorm2d, Conv2d, Linear, ReLU, Sequential

from ._container import Classifier, make_conv_pool_activ


def _make_seq(in_channels, out_channels, c_kernel_size, gc_stride, gc_kernel_size=3):
    """Assemble a two-stage block: helper-built conv stage, then a grouped conv.

    Stage one is whatever ``make_conv_pool_activ`` yields for
    ``in_channels -> out_channels`` with kernel ``c_kernel_size``, no bias and
    padding ``(c_kernel_size - 1) // 2`` (the helper lives in ``._container``;
    presumably it starts with a convolution -- confirm there).  Stage two is a
    bias-free ``Conv2d`` with ``groups=out_channels``, i.e. one filter group
    per channel, using ``gc_kernel_size``, ``gc_stride`` and padding
    ``(gc_kernel_size - 1) // 2``.  Each stage is followed by
    ``BatchNorm2d(out_channels, eps=0.001)`` and ``ReLU``.

    Args:
        in_channels: Channel count fed into the first stage.
        out_channels: Channel count produced by the first stage and kept by
            the grouped convolution.
        c_kernel_size: Kernel size handed to ``make_conv_pool_activ``.
        gc_stride: Stride of the grouped convolution.
        gc_kernel_size: Kernel size of the grouped convolution.

    Returns:
        A ``Sequential`` holding the layers in execution order.
    """
    # Modules are instantiated in the same order they run, so any
    # RNG-dependent parameter initialisation happens in a fixed sequence.
    layers = list(
        make_conv_pool_activ(
            in_channels,
            out_channels,
            c_kernel_size,
            bias=False,
            padding=(c_kernel_size - 1) // 2,
        )
    )
    layers.append(BatchNorm2d(out_channels, eps=0.001))
    layers.append(ReLU())

    # groups == channels: every channel is convolved independently.
    gc_padding = (gc_kernel_size - 1) // 2
    per_channel_conv = Conv2d(
        out_channels,
        out_channels,
        gc_kernel_size,
        bias=False,
        stride=gc_stride,
        padding=gc_padding,
        groups=out_channels,
    )
    layers.append(per_channel_conv)
    layers.append(BatchNorm2d(out_channels, eps=0.001))
    layers.append(ReLU())

    return Sequential(*layers)


class MobileNet(Classifier):
    """MobileNet-style network built from thirteen ``_make_seq`` stages.

    The feature extractor chains the stages listed in ``__init__`` (3 input
    channels widening to 1024), then the layers returned by
    ``make_conv_pool_activ(1024, 1024, 1, ...)``, a ``BatchNorm2d``, a
    ``ReLU`` and a final ``AvgPool2d(2)``.  The head is a single
    ``Linear(1024, 10)``.  Both parts are handed to ``Classifier.__init__``;
    how the base class joins them is defined in ``._container``.
    """

    def __init__(self):
        """Create the feature extractor and linear head and pass them to the base class."""
        # One row per stage, in execution order:
        # (in_channels, out_channels, c_kernel_size, gc_stride)
        stage_specs = (
            (3, 32, 3, 1),
            (32, 64, 1, 2),
            (64, 128, 1, 1),
            (128, 128, 1, 2),
            (128, 256, 1, 1),
            (256, 256, 1, 2),
            (256, 512, 1, 1),
            (512, 512, 1, 1),
            (512, 512, 1, 1),
            (512, 512, 1, 1),
            (512, 512, 1, 1),
            (512, 512, 1, 2),
            (512, 1024, 1, 1),
        )

        # Build modules strictly in order so child indices inside the
        # Sequential (and hence state_dict keys) stay the same.
        feature_layers = [_make_seq(*spec) for spec in stage_specs]
        feature_layers.extend(
            make_conv_pool_activ(1024, 1024, 1, padding=0, bias=False)
        )
        feature_layers.append(BatchNorm2d(1024, eps=0.001))
        feature_layers.append(ReLU())
        feature_layers.append(AvgPool2d(2))

        convs = Sequential(*feature_layers)
        linears = Sequential(Linear(1024, 10))
        super().__init__(convs, linears)