import unittest

from predtuner.approxes import get_knobs_from_file
from predtuner.torchapp import TorchApp
from predtuner.torchutil import accuracy
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models.vgg import vgg16


class TestTorchAppInit(unittest.TestCase):
    """Smoke test: a ``TorchApp`` can be built around a pretrained VGG16 and CIFAR-10."""

    def setUp(self):
        """Prepare the CIFAR-10 dataset and a pretrained VGG16 model for each test.

        The dataset is stored under ``/tmp/cifar10`` and downloaded there if it
        is not already present; images are converted to tensors with ``ToTensor``.
        """
        transform = transforms.Compose([transforms.ToTensor()])
        self.dataset = CIFAR10("/tmp/cifar10", download=True, transform=transform)
        # NOTE(review): newer torchvision releases deprecate ``pretrained=True`` in
        # favor of ``weights=...``; kept as-is for compatibility with older versions.
        self.module = vgg16(pretrained=True)

    def test_init(self):
        """Construct a ``TorchApp`` and check that an app object is produced."""
        # The same dataset backs both data loaders; this test only exercises
        # construction, not tuning or evaluation.
        app = TorchApp(
            "test",
            self.module,
            DataLoader(self.dataset),
            DataLoader(self.dataset),
            get_knobs_from_file(),
            accuracy,
        )
        # Without an assertion the test only checked that construction did not raise.
        self.assertIsInstance(app, TorchApp)