import unittest

from torch.utils.data.dataset import Subset

from predtuner.approxes import get_knobs_from_file
from predtuner.torchapp import TorchApp
from predtuner.torchutil import accuracy
from torch.nn import Conv2d, Linear
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models.vgg import vgg16


class TestTorchApp(unittest.TestCase):
    """Integration tests for ``TorchApp`` using a pretrained VGG16 on a CIFAR10 subset."""

    def setUp(self):
        """Build a 100-image CIFAR10 subset and a pretrained VGG16 for each test.

        NOTE(review): ``download=True`` and ``pretrained=True`` may fetch data and
        weights on first run, so these tests presumably need network access.
        """
        # These look like the standard ImageNet channel statistics, presumably
        # chosen to match the pretrained VGG16 weights.
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        transform = transforms.Compose([transforms.ToTensor(), normalize])

        dataset = CIFAR10("/tmp/cifar10", download=True, transform=transform)
        # Only the first 100 samples, to keep the tests short.
        self.dataset = Subset(dataset, range(100))
        self.module = vgg16(pretrained=True)

    def get_app(self, name="TestTorchApp", batch_size=1):
        """Construct a ``TorchApp`` over ``self.module`` and ``self.dataset``.

        :param name: application name passed to ``TorchApp``.
        :param batch_size: batch size used for both data loaders. The default
            of 1 matches ``DataLoader``'s own default.
        :return: the constructed ``TorchApp``.
        """
        return TorchApp(
            name,
            self.module,
            DataLoader(self.dataset, batch_size=batch_size),
            DataLoader(self.dataset, batch_size=batch_size),
            get_knobs_from_file(),
            accuracy,
        )

    def test_init(self):
        """Every module gets the expected number of knobs for its layer type."""
        app = self.get_app()
        n_knobs = {op: len(ks) for op, ks in app.op_knobs.items()}
        for op_name, op in app.midx.name_to_module.items():
            # NOTE(review): these counts presumably reflect the default knob
            # file loaded by get_knobs_from_file(); update them if it changes.
            if isinstance(op, Conv2d):
                nknob = 56
            elif isinstance(op, Linear):
                nknob = 2
            else:
                nknob = 1
            # subTest reports every mismatching op instead of stopping at the
            # first; assertIn gives a clear failure instead of a bare KeyError.
            with self.subTest(op=op_name):
                self.assertIn(op_name, n_knobs)
                self.assertEqual(n_knobs[op_name], nknob)

    # TODO: add a baseline-QoS test, e.g. ``app.measure_qos_perf({}, False)``,
    # once there is an expected value to assert against.

    def test_tuning(self):
        """Smoke test: tuning a batched app runs to completion without raising."""
        app = self.get_app("test", batch_size=4)
        tuner = app.get_tuner()
        # NOTE(review): the meaning of the positional arguments (10, 3.0) is
        # defined by the tuner's ``tune`` signature; consider passing them by
        # keyword for readability.
        tuner.tune(10, 3.0)