import unittest
import torch

from torch.utils.data.dataset import Subset

from predtuner.approxes import get_knobs_from_file
from predtuner.torchapp import TorchApp
from predtuner.torchutil import accuracy
from torch.nn import Conv2d, Linear
from torch.utils.data.dataloader import DataLoader
from model_zoo import VGG16Cifar10, CIFAR


class TestTorchApp(unittest.TestCase):
    """Integration tests for ``TorchApp`` wrapping a pretrained VGG16 CIFAR-10 model.

    Every test works on the first 100 samples of the CIFAR-10 data files under
    ``model_data/`` to keep run time down.
    """

    def setUp(self):
        dataset = CIFAR.from_file(
            "model_data/cifar10/input.bin", "model_data/cifar10/labels.bin"
        )
        # Restrict to 100 samples so QoS measurement and tuning stay cheap.
        self.dataset = Subset(dataset, range(100))
        self.module = VGG16Cifar10()
        # map_location="cpu" lets a checkpoint that was saved from a GPU load on
        # CPU-only machines; load_state_dict copies the values into the
        # module's own parameters regardless of the device they were loaded on.
        self.module.load_state_dict(
            torch.load("model_data/vgg16_cifar10.pth.tar", map_location="cpu")
        )

    def get_app(self, name="TestTorchApp", batch_size=1):
        """Build a ``TorchApp`` over ``self.module`` and ``self.dataset``.

        Both DataLoaders handed to ``TorchApp`` iterate the same 100-sample
        subset. Knobs come from ``get_knobs_from_file()``, and ``accuracy`` is
        passed as the metric (presumably the QoS function — confirm against
        the TorchApp signature).

        :param name: application name passed to ``TorchApp``.
        :param batch_size: batch size for both DataLoaders; the default of 1
            is the same as ``DataLoader``'s own default.
        :return: the constructed ``TorchApp``.
        """
        return TorchApp(
            name,
            self.module,
            DataLoader(self.dataset, batch_size=batch_size),
            DataLoader(self.dataset, batch_size=batch_size),
            get_knobs_from_file(),
            accuracy,
        )

    def test_init(self):
        """Check the per-op knob counts: 56 for Conv2d, 2 for Linear, 1 otherwise.

        Also checks that exactly 34 ops receive knobs.
        """
        app = self.get_app()
        n_knobs = {op: len(ks) for op, ks in app.op_knobs.items()}
        self.assertEqual(len(n_knobs), 34)
        for op_name, op in app.midx.name_to_module.items():
            if isinstance(op, Conv2d):
                expected = 56
            elif isinstance(op, Linear):
                expected = 2
            else:
                expected = 1
            # subTest reports every mismatching op instead of stopping at the
            # first one; assertIn gives a clear failure rather than a KeyError.
            with self.subTest(op=op_name):
                self.assertIn(op_name, n_knobs)
                self.assertEqual(n_knobs[op_name], expected)

    def test_baseline_qos(self):
        """Check that QoS with an empty knob configuration is 0.88 on the 100-sample subset."""
        app = self.get_app()
        # An empty config presumably means "no approximation" (baseline); the
        # second argument's meaning is defined by measure_qos_perf — confirm.
        qos, _ = app.measure_qos_perf({}, False)
        self.assertAlmostEqual(qos, 0.88)

    def test_tuning(self):
        """Smoke test: a short tuning run with batch size 4 completes without raising."""
        app = self.get_app("test", batch_size=4)
        tuner = app.get_tuner()
        # NOTE(review): the meaning of (10, 3.0) is defined by the tuner API
        # (presumably iteration budget and QoS-loss threshold) — confirm.
        tuner.tune(10, 3.0)