from torch.nn import AvgPool2d, BatchNorm2d, Linear, Module, ReLU, Sequential

from ._container import Classifier, make_conv_pool_activ


class BasicBlock(Module):
    """Residual block whose output is ``ReLU(mainline(x) + shortcut(x))``.

    The main path chains two units built by ``make_conv_pool_activ``; the
    skip path is either an empty ``Sequential`` (which returns its input
    unchanged) or a single strided unit built by the same helper.

    NOTE(review): ``make_conv_pool_activ`` lives in ``._container`` and is not
    visible here. Its positional arguments are presumably (in channels,
    out channels, kernel size, activation class) -- confirm against the
    helper before relying on the parameter descriptions below.

    Args:
        ins: First argument of the first main-path unit and of the skip
            unit (presumably the input channel count).
        outs: Output size handed to every unit in the block (presumably the
            output channel count).
        shortcut: When truthy, the first main-path unit and the skip unit
            are built with ``stride=2``; otherwise the stride is 1 and the
            skip path is empty.
    """

    def __init__(self, ins, outs, shortcut=False):
        super().__init__()

        stride = 1
        if shortcut:
            stride = 2

        # Build order (main path first, skip path last) is kept deliberately:
        # it fixes both the submodule registration order and the order in
        # which any randomly initialised parameters are drawn.
        main_layers = [
            *make_conv_pool_activ(ins, outs, 3, ReLU, padding=1, stride=stride),
            *make_conv_pool_activ(outs, outs, 3, padding=1),
        ]
        self.mainline = Sequential(*main_layers)
        self.relu1 = ReLU()

        if shortcut:
            skip_layers = make_conv_pool_activ(ins, outs, 1, stride=stride)
            self.shortcut = Sequential(*skip_layers)
        else:
            # No layers: the skip path passes the input through untouched.
            self.shortcut = Sequential()

    def forward(self, input_):
        """Return the activated sum of the main path and the skip path."""
        main_out = self.mainline(input_)
        skip_out = self.shortcut(input_)
        return self.relu1(main_out + skip_out)


class ResNet18(Classifier):
    """Classifier made of a stem unit, nine ``BasicBlock``s and one linear layer.

    The convolutional part is: one ``make_conv_pool_activ(3, 16, 3, ReLU,
    padding=1)`` stem, three groups of three ``BasicBlock``s at 16, 32 and 64
    channels (the first block of the 2nd and 3rd group is built with
    ``shortcut=True``), then ``AvgPool2d(8)``. The linear part is a single
    ``Linear(64, 10)``.

    NOTE(review): this layout (3 x 3 blocks, 16/32/64 widths) looks like the
    CIFAR-style ResNet rather than the ImageNet ResNet-18 -- confirm that the
    class name is intentional.
    NOTE(review): ``AvgPool2d(8)`` feeding ``Linear(64, 10)`` presumably
    expects 3x32x32 inputs and 10 classes -- TODO confirm against callers and
    against how ``Classifier`` joins the two parts.
    """

    def __init__(self):
        # (ins, outs, shortcut) for each residual block, in network order.
        block_specs = (
            (16, 16, False),
            (16, 16, False),
            (16, 16, False),
            (16, 32, True),
            (32, 32, False),
            (32, 32, False),
            (32, 64, True),
            (64, 64, False),
            (64, 64, False),
        )

        # The stem is materialised before any block is created so that layers
        # are constructed in the same order as they appear in the network.
        layers = [*make_conv_pool_activ(3, 16, 3, ReLU, padding=1)]
        for ins, outs, downsample in block_specs:
            layers.append(BasicBlock(ins, outs, downsample))
        layers.append(AvgPool2d(8))

        convs = Sequential(*layers)
        linears = Sequential(Linear(64, 10))
        super().__init__(convs, linears)


class Bottleneck(Module):
    """Three-unit residual block whose output width is ``expansion * planes``.

    The main path is: a kernel-argument-1 unit (carrying ``stride``), batch
    norm, ReLU, a kernel-argument-3 unit with ``padding=1``, batch norm, ReLU,
    a kernel-argument-1 unit widening to ``expansion * planes``, batch norm.
    The block returns ``ReLU(mainline(x) + shortcut(x))``.

    NOTE(review): ``make_conv_pool_activ`` lives in ``._container`` and is not
    visible here; its positional arguments are presumably (in channels,
    out channels, kernel size) -- confirm against the helper.

    Args:
        in_planes: Input size of the first main-path unit and of the skip
            projection (presumably the input channel count).
        planes: Width of the two inner units; the block's output width is
            ``expansion * planes``.
        stride: Stride given to the first main-path unit and to the skip
            projection.
    """

    # Ratio between the block's output width and ``planes``.
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        def batch_norm(channels):
            # Every normalisation layer in this block uses the same epsilon.
            return BatchNorm2d(channels, eps=0.001)

        # Layers are appended in network order so that construction order
        # (and therefore parameter initialisation order) is preserved.
        main_layers = []
        main_layers.extend(
            make_conv_pool_activ(in_planes, planes, 1, stride=stride)
        )
        main_layers.append(batch_norm(planes))
        main_layers.append(ReLU())
        main_layers.extend(make_conv_pool_activ(planes, planes, 3, padding=1))
        main_layers.append(batch_norm(planes))
        main_layers.append(ReLU())
        main_layers.extend(make_conv_pool_activ(planes, out_planes, 1))
        main_layers.append(batch_norm(out_planes))

        self.mainline = Sequential(*main_layers)
        self.relu1 = ReLU()

        if stride == 1 and in_planes == out_planes:
            # Main path output matches the input size and stride: the skip
            # path can pass the input through untouched.
            self.shortcut = Sequential()
        else:
            # Otherwise project the input so it can be added to the main path.
            self.shortcut = Sequential(
                *make_conv_pool_activ(in_planes, out_planes, 1, stride=stride),
                batch_norm(out_planes),
            )

    def forward(self, input_):
        """Return the activated sum of the main path and the skip path."""
        main_out = self.mainline(input_)
        skip_out = self.shortcut(input_)
        return self.relu1(main_out + skip_out)


class ResNet50(Classifier):
    """Classifier made of a stem, sixteen ``Bottleneck`` blocks and one linear layer.

    The convolutional part is: a ``make_conv_pool_activ`` stem (3 -> 64,
    kernel argument 7, ``stride=2``, ``padding=3``, ``pool_size=3``,
    ``pool_stride=2``, ReLU) followed by ``BatchNorm2d(64, eps=0.001)``, then
    four stages of 3, 4, 6 and 3 ``Bottleneck`` blocks with ``planes`` of 64,
    128, 256 and 512 (stages 2-4 open with a ``stride=2`` block), then
    ``AvgPool2d(7)``. The linear part is a single ``Linear(2048, 1000)``.

    NOTE(review): ``AvgPool2d(7)`` feeding ``Linear(2048, 1000)`` presumably
    expects 3x224x224 inputs and 1000 classes -- TODO confirm against callers
    and against how ``Classifier`` joins the two parts.
    """

    def __init__(self):
        # (in_planes, planes, stride, block count) for each stage. Only the
        # first block of a stage receives ``in_planes`` and ``stride``; the
        # remaining blocks take the previous block's output width.
        stage_specs = (
            (64, 64, 1, 3),
            (256, 128, 2, 4),
            (512, 256, 2, 6),
            (1024, 512, 2, 3),
        )

        # The stem is materialised before any block is created so that layers
        # are constructed in the same order as they appear in the network.
        layers = [
            *make_conv_pool_activ(
                3, 64, 7, ReLU, pool_size=3, pool_stride=2, padding=3, stride=2
            )
        ]
        layers.append(BatchNorm2d(64, eps=0.001))

        for in_planes, planes, stride, count in stage_specs:
            layers.append(Bottleneck(in_planes, planes, stride=stride))
            widened = Bottleneck.expansion * planes
            for _ in range(count - 1):
                layers.append(Bottleneck(widened, planes))

        layers.append(AvgPool2d(7))

        convs = Sequential(*layers)
        linears = Sequential(Linear(2048, 1000))
        super().__init__(convs, linears)