"""Make PyTorch checkpoint of MiniERA model from legacy HPVM weights."""
import site
from pathlib import Path

import torch

# Absolute path of the directory holding this script. It is registered as a
# site directory (``site.addsitedir`` appends it to ``sys.path`` and also
# processes any ``.pth`` files found there) so that the local ``torch_dnn``
# module imported below can be resolved.
self_folder = Path(__file__).absolute().parent
site.addsitedir(self_folder)

from torch_dnn import CIFAR, MiniERA


@torch.no_grad()
def main():
    """Build a PyTorch checkpoint of MiniERA from legacy HPVM weights.

    Loads the legacy weights from ``assets/miniera`` next to this script,
    prints top-1 accuracy on the bundled ``input.bin`` / ``labels.bin`` test
    set as a sanity check of the conversion, and finally writes the model's
    ``state_dict`` to ``assets/miniera/miniera.pth``.
    """
    prefix = self_folder / "assets/miniera"
    model = MiniERA().load_legacy_hpvm_weights(prefix)
    # Switch to inference mode before evaluating and saving. ``no_grad`` only
    # disables autograd; it does not change module behaviour. In training mode
    # any dropout / batch-norm layers would perturb the predictions, and
    # batch-norm would overwrite its running statistics, which are part of the
    # state_dict saved below.
    model.eval()

    # Sanity-check the converted weights on the bundled test set.
    dataset = CIFAR.from_file(prefix / "input.bin", prefix / "labels.bin")
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
    correct = 0
    total = 0
    for batch in dataloader:
        images, labels = batch[0], batch[1]
        # Index of the highest-scoring class per sample (top-1 prediction).
        predicted = model(images).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    # Guard against an empty test set instead of dividing by zero; the
    # checkpoint is still written either way.
    if total:
        print(f"Accuracy of the network on the test images: {100 * correct / total} %")
    else:
        print("No test images found; skipping accuracy check.")
    torch.save(model.state_dict(), prefix / "miniera.pth")


# Run the conversion only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()