import unittest

from predtuner.approxes import get_knobs_from_file
from predtuner.torchapp import TorchApp
from predtuner.torchutil import accuracy
from torch.nn import Conv2d, Linear
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models.vgg import vgg16


class TestTorchAppInit(unittest.TestCase):
    """Check that ``TorchApp`` assigns the expected number of knobs to each module."""

    # Expected knob counts per module type. These presumably reflect the
    # default knob file read by ``get_knobs_from_file()`` -- TODO confirm
    # and update these if that knob file changes.
    CONV2D_KNOBS = 63
    LINEAR_KNOBS = 9
    OTHER_KNOBS = 1

    @classmethod
    def setUpClass(cls):
        # Build (and download, if missing) CIFAR10 once for the whole class
        # instead of before every test; here it is only wrapped in DataLoaders.
        transform = transforms.Compose([transforms.ToTensor()])
        cls.dataset = CIFAR10("/tmp/cifar10", download=True, transform=transform)

    def setUp(self):
        # A fresh model per test, in case TorchApp modifies the module it is given.
        self.module = vgg16(pretrained=True)

    def test_init(self):
        """Every module in the app's module index gets the expected knob count.

        Conv2d modules expect ``CONV2D_KNOBS``, Linear modules expect
        ``LINEAR_KNOBS``, and every other module expects ``OTHER_KNOBS``.
        """
        app = TorchApp(
            "test",
            self.module,
            DataLoader(self.dataset),
            DataLoader(self.dataset),
            get_knobs_from_file(),
            accuracy,
        )
        n_knobs = {op: len(ks) for op, ks in app.op_knobs.items()}
        for op_name, op in app.midx.name_to_module.items():
            if isinstance(op, Conv2d):
                expected = self.CONV2D_KNOBS
            elif isinstance(op, Linear):
                expected = self.LINEAR_KNOBS
            else:
                expected = self.OTHER_KNOBS
            # subTest reports every mismatching op instead of stopping at the
            # first one; assertIn turns a missing op into a test failure
            # rather than a KeyError.
            with self.subTest(op=op_name):
                self.assertIn(op_name, n_knobs)
                self.assertEqual(n_knobs[op_name], expected)