import unittest

from predtuner.approxes import get_knobs_from_file
from predtuner.torchapp import TorchApp
from predtuner.torchutil import accuracy
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models.vgg import vgg16


class TestTorchAppInit(unittest.TestCase):
    """Smoke test: a ``TorchApp`` can be constructed from a VGG16 module."""

    def setUp(self):
        """Build the fixtures used by the test.

        Stores a CIFAR10 dataset (rooted at ``/tmp/cifar10``, created with
        ``download=True``) on ``self.dataset`` and a ``vgg16`` model created
        with ``pretrained=True`` on ``self.module``.
        """
        to_tensor = transforms.ToTensor()
        pipeline = transforms.Compose([to_tensor])
        self.dataset = CIFAR10("/tmp/cifar10", download=True, transform=pipeline)
        self.module = vgg16(pretrained=True)

    def test_init(self):
        """Construct a ``TorchApp``; the test passes if no exception is raised.

        No assertions are made on the resulting object.
        """
        # Two independent loaders over the same dataset are handed to
        # TorchApp. NOTE(review): presumably one is for tuning and one for
        # testing -- confirm against the TorchApp signature.
        first_loader = DataLoader(self.dataset)
        second_loader = DataLoader(self.dataset)
        knobs = get_knobs_from_file()
        app = TorchApp(
            "test",
            self.module,
            first_loader,
            second_loader,
            knobs,
            accuracy,
        )