import unittest

import torch
from model_zoo import CIFAR, VGG16Cifar10
from predtuner import TorchApp, accuracy, config_pylogger, get_knobs_from_file
from torch.nn import Conv2d, Linear
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Subset

# Configured once at import time; output_dir presumably controls where predtuner
# writes its log output (here /tmp) -- confirm against config_pylogger.
msg_logger = config_pylogger(verbose=True, output_dir="/tmp")


class TorchAppSetUp(unittest.TestCase):
    """Shared fixture wrapping a pretrained VGG16 (CIFAR-10) in a ``TorchApp``.

    The fixture is built in ``setUpClass``, so the data files and the checkpoint
    are read once per test class rather than once per test. Subclasses get:

    * ``dataset`` -- a ``Subset`` holding the first 100 CIFAR-10 samples;
    * ``module``  -- a ``VGG16Cifar10`` with weights from the checkpoint file;
    * ``app``     -- a ``TorchApp`` fed ``dataset`` through two data loaders,
      configured with the knobs from ``get_knobs_from_file()`` and the
      ``accuracy`` QoS function.
    """

    @classmethod
    def setUpClass(cls):
        # Paths are relative, so tests must run from the directory that
        # contains ``model_data/``.
        cifar = CIFAR.from_file(
            "model_data/cifar10/input.bin", "model_data/cifar10/labels.bin"
        )
        cls.dataset = Subset(cifar, range(100))

        cls.module = VGG16Cifar10()
        state_dict = torch.load("model_data/vgg16_cifar10.pth.tar")
        cls.module.load_state_dict(state_dict)

        # batch_size=500 exceeds the 100-sample subset, so each loader yields
        # the whole subset as one batch. Both loaders (presumably tune/test)
        # deliberately read the same data.
        tune_loader = DataLoader(cls.dataset, batch_size=500)
        test_loader = DataLoader(cls.dataset, batch_size=500)
        knobs = get_knobs_from_file()
        cls.app = TorchApp(
            "TestTorchApp", cls.module, tune_loader, test_loader, knobs, accuracy
        )


class TestTorchAppTuning(TorchAppSetUp):
    """Knob enumeration, baseline QoS, tuning with relative thresholds, and models."""

    def test_knobs(self):
        """Each operator exposes the expected number of knobs for its layer type."""
        n_knobs = {op: len(ks) for op, ks in self.app.op_knobs.items()}
        self.assertEqual(len(n_knobs), 34)
        for op_name, op in self.app.midx.name_to_module.items():
            if isinstance(op, Conv2d):
                expected = 56
            elif isinstance(op, Linear):
                expected = 2
            else:
                # Any other module type is expected to have exactly one knob.
                expected = 1
            # subTest reports every mismatching operator, not just the first.
            with self.subTest(op=op_name):
                self.assertEqual(n_knobs[op_name], expected)

    def test_baseline_qos(self):
        """The un-approximated app (empty config) reaches the known QoS."""
        qos, _ = self.app.measure_qos_perf({}, False)
        self.assertAlmostEqual(qos, 88.0)

    def test_tuning_relative_thres(self):
        """Kept configs stay within 3.0 QoS of the baseline; at most 10 are 'best'."""
        baseline, _ = self.app.measure_qos_perf({}, False)
        tuner = self.app.get_tuner()
        # NOTE(review): positional order assumed to be (max_iter, tuning
        # threshold, keep threshold, is_threshold_relative, take_best_n);
        # confirm against the tuner's tune() signature.
        tuner.tune(100, 3.0, 3.0, True, 10)
        for conf in tuner.kept_configs:
            # assertGreater shows both operands on failure, unlike assertTrue.
            self.assertGreater(conf.qos, baseline - 3.0)
        if len(tuner.kept_configs) >= 10:
            self.assertEqual(len(tuner.best_configs), 10)

    def test_enum_models(self):
        """The app advertises exactly the expected perf/QoS model names."""
        self.assertSetEqual(
            {model.name for model in self.app.get_models()},
            {"perf_linear", "qos_p1", "qos_p2"},
        )


class TestTorchAppTunerResult(TorchAppSetUp):
    """Properties of the configs produced by one absolute-threshold tuning run.

    The tuning run happens once in ``setUpClass``; each test inspects the
    shared ``tuner`` afterwards.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.baseline, _ = cls.app.measure_qos_perf({}, False)
        cls.tuner = cls.app.get_tuner()
        # Threshold passed as an absolute value: 3.0 below the measured baseline.
        cls.tuner.tune(100, cls.baseline - 3.0)

    def test_results_qos(self):
        """Every kept config's QoS is strictly above ``baseline - 3.0``."""
        threshold = self.baseline - 3.0
        for conf in self.tuner.kept_configs:
            # assertGreater reports the offending values on failure.
            self.assertGreater(conf.qos, threshold)

    def test_pareto(self):
        """No best config is strictly beaten on both qos and perf by another."""
        configs = self.tuner.best_configs
        for c1 in configs:
            dominators = [
                c2 for c2 in configs if c2.qos > c1.qos and c2.perf > c1.perf
            ]
            # Empty list passes; on failure the dominating configs are shown.
            self.assertFalse(
                dominators, "best_configs contains a strictly dominated config"
            )

    def test_dummy_calib(self):
        """Calibrated QoS equals tuning QoS for every best config.

        Presumably this holds because TorchAppSetUp feeds the same subset as
        both loaders -- confirm if the fixture changes.
        """
        for c in self.tuner.best_configs:
            self.assertAlmostEqual(c.calib_qos, c.qos)


class TestModeledTuning(TorchAppSetUp):
    """Tuning driven by the ``perf_linear`` perf model plus each QoS model.

    NOTE(review): these tests contain no assertions -- they only check that
    tuning completes without raising. ``baseline`` is computed below but no
    test in this class reads it.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.baseline, _perf = cls.app.measure_qos_perf({}, False)

    def _tune_with_qos_model(self, qos_model_name):
        """Run 100 iterations with a relative 3.0 threshold and the given QoS model."""
        fresh_tuner = self.app.get_tuner()
        fresh_tuner.tune(
            100,
            3.0,
            qos_model=qos_model_name,
            perf_model="perf_linear",
            is_threshold_relative=True,
        )

    def test_qos_p1(self):
        self._tune_with_qos_model("qos_p1")

    def test_qos_p2(self):
        self._tune_with_qos_model("qos_p2")