from pathlib import Path
from typing import Union

import numpy as np
import torch
from torch.nn import Conv2d, Linear, MaxPool2d, Module, ReLU, Sequential, Softmax


class MiniERA(Module):
    """Small CNN classifier: four 3x3 convolutions, two fully-connected layers, 5 outputs.

    The first ``Linear`` takes 1600 = 64 * 5 * 5 features, which is what the
    convolution stack produces for a 3x32x32 input
    (32 -> 30 -> 28 -> pool 14 -> 12 -> 10 -> pool 5).
    """

    def __init__(self):
        super().__init__()
        self.convs = Sequential(
            Conv2d(3, 32, 3),
            ReLU(),
            Conv2d(32, 32, 3),
            ReLU(),
            MaxPool2d(2, 2),
            Conv2d(32, 64, 3),
            ReLU(),
            Conv2d(64, 64, 3),
            ReLU(),
            MaxPool2d(2, 2),
        )
        self.fcs = Sequential(Linear(1600, 256), ReLU(), Linear(256, 5))
        # Softmax over dim 1, so forward() returns per-class probabilities.
        self.softmax = Softmax(1)

    def forward(self, input):
        """Return softmax class probabilities of shape ``(N, 5)`` for batch ``input``."""
        outputs = self.convs(input)
        # Flatten everything except the batch dimension before the dense layers.
        outputs = self.fcs(outputs.flatten(1, -1))
        return self.softmax(outputs)

    @staticmethod
    def _read_bin(path: Path) -> torch.Tensor:
        """Read a raw float32 binary file (native byte order) into a 1-D tensor."""
        return torch.from_numpy(np.fromfile(path, dtype=np.float32))

    def load_legacy_hpvm_weights(self, prefix: Union[Path, str]) -> "MiniERA":
        """Load parameters from a legacy HPVM raw-binary weight dump.

        Under the directory ``prefix``, the i-th ``Conv2d`` (1-based, in module
        order) is read from ``conv2d_{i}_w.bin`` / ``conv2d_{i}_b.bin``. The i-th
        ``Linear`` is read from ``dense_{i}_w.bin`` / ``dense_{i}_b.bin``. Every file
        holds raw float32 values. Dense weights are stored as
        ``(in_features, out_features)`` and are transposed into PyTorch's
        ``(out_features, in_features)`` layout.

        Values are copied into the existing parameters in place. The model therefore
        keeps its current device and dtype, so loading after ``.to(device)`` or
        ``.double()`` works, and all parameters stay contiguous.

        Args:
            prefix: Directory containing the ``.bin`` files.

        Returns:
            ``self``, to allow chaining.

        Raises:
            FileNotFoundError: If an expected weight file is missing.
            RuntimeError: If a file's element count does not match its layer.
        """
        prefix = Path(prefix)
        convs = [m for m in self.convs if isinstance(m, Conv2d)]
        linears = [m for m in self.fcs if isinstance(m, Linear)]
        # copy_ under no_grad (instead of replacing .data) preserves the existing
        # parameters' device and dtype and keeps autograd out of the loading step.
        with torch.no_grad():
            for idx, conv in enumerate(convs, start=1):
                weight = self._read_bin(prefix / f"conv2d_{idx}_w.bin")
                bias = self._read_bin(prefix / f"conv2d_{idx}_b.bin")
                conv.weight.copy_(weight.reshape(conv.weight.shape))
                conv.bias.copy_(bias.reshape(conv.bias.shape))
            for idx, linear in enumerate(linears, start=1):
                weight = self._read_bin(prefix / f"dense_{idx}_w.bin")
                bias = self._read_bin(prefix / f"dense_{idx}_b.bin")
                cout, cin = linear.weight.shape
                # Legacy dense layout is (in, out); transpose to PyTorch's (out, in).
                linear.weight.copy_(weight.reshape(cin, cout).T)
                linear.bias.copy_(bias.reshape(linear.bias.shape))
        return self