import site
from pathlib import Path

import torch
from torch.nn.modules.module import Module
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Subset

site.addsitedir(Path(__file__).absolute().parent.parent)
import model_zoo as net
from predtuner import TorchApp, accuracy, get_knobs_from_file


def load_from_default_path(cls, prefix: str):
    """Build a dataset by calling ``cls.from_file`` on the files under ``prefix``.

    The two paths handed to ``cls.from_file`` are ``<prefix>/input.bin``
    and ``<prefix>/labels.bin``, in that order. ``prefix`` is joined with a
    literal ``/`` and is not normalized, so a trailing slash is preserved.
    """
    inputs_file = f"{prefix}/input.bin"
    labels_file = f"{prefix}/labels.bin"
    return cls.from_file(inputs_file, labels_file)


# Datasets, each built from the ``input.bin`` / ``labels.bin`` pair in its folder.
mnist = load_from_default_path(net.MNIST, "model_data/mnist")
cifar10 = load_from_default_path(net.CIFAR, "model_data/cifar10")
cifar100 = load_from_default_path(net.CIFAR, "model_data/cifar100")
# NOTE(review): ``imagenet`` is loaded but no entry of ``networks_in_folder``
# below uses it — confirm whether this (potentially expensive) load is needed.
imagenet = load_from_default_path(net.ImageNet, "model_data/imagenet")

# Checkpoint name -> (network class, dataset to evaluate it on).
# Weights for entry ``<name>`` are read from ``model_data/<name>.pth.tar``.
networks_in_folder = {
    "lenet_mnist": (net.LeNet, mnist),
    "alexnet_cifar10": (net.AlexNet, cifar10),
    "alexnet2_cifar10": (net.AlexNet2, cifar10),
    "vgg16_cifar10": (net.VGG16Cifar10, cifar10),
    "vgg16_cifar100": (net.VGG16Cifar100, cifar100),
}

for name, (cls, dataset) in networks_in_folder.items():
    network: Module = cls()
    # Deserialize onto CPU. load_state_dict copies the values into the
    # module's own parameters regardless, and without map_location a
    # checkpoint saved from CUDA tensors fails to load on a CPU-only machine.
    network.load_state_dict(
        torch.load(f"model_data/{name}.pth.tar", map_location="cpu")
    )
    # d1: dataset indices 5000-9999; d2: the whole dataset; both batch size 1.
    # Presumably these are TorchApp's tuning and test loaders — TODO confirm.
    d1 = DataLoader(Subset(dataset, range(5000, 10000)), 1)
    d2 = DataLoader(dataset, 1)
    app = TorchApp("", network, d1, d2, get_knobs_from_file(), accuracy)
    # NOTE(review): ``{}`` presumably means "no approximation" (baseline) and
    # ``False`` presumably selects the tuning loader (d1) — confirm in predtuner.
    qos, _ = app.measure_qos_perf({}, False)
    print(f"{name} -> {qos}")