import unittest

from torch.nn import Conv2d, Linear
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Subset
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models.vgg import vgg16

from predtuner.approxes import get_knobs_from_file
from predtuner.torchapp import TorchApp
from predtuner.torchutil import accuracy


class TestTorchApp(unittest.TestCase):
    """Tests for ``TorchApp`` on a pretrained VGG16 over a CIFAR-10 subset."""

    def setUp(self):
        """Build a 100-sample CIFAR-10 subset and a pretrained VGG16 model."""
        # NOTE(review): these are the commonly used ImageNet channel statistics,
        # presumably chosen to match the pretrained VGG16 weights — confirm.
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        transform = transforms.Compose([transforms.ToTensor(), normalize])
        dataset = CIFAR10("/tmp/cifar10", download=True, transform=transform)
        # Keep only the first 100 samples to bound test runtime.
        self.dataset = Subset(dataset, range(100))
        self.module = vgg16(pretrained=True)

    def get_app(self, name="TestTorchApp", batch_size=1):
        """Construct a ``TorchApp`` over ``self.module`` and ``self.dataset``.

        Args:
            name: Application name passed as the first argument to ``TorchApp``.
            batch_size: Batch size for both DataLoaders handed to ``TorchApp``.
                The default of 1 equals ``DataLoader``'s own default.

        Returns:
            A ``TorchApp`` built with the knobs from ``get_knobs_from_file()``
            and ``accuracy`` (presumably its QoS metric — confirm).
        """
        return TorchApp(
            name,
            self.module,
            DataLoader(self.dataset, batch_size=batch_size),
            DataLoader(self.dataset, batch_size=batch_size),
            get_knobs_from_file(),
            accuracy,
        )

    def test_init(self):
        """Check that each indexed module gets the expected number of knobs."""
        app = self.get_app()
        n_knobs = {op: len(ks) for op, ks in app.op_knobs.items()}
        for op_name, op in app.midx.name_to_module.items():
            # NOTE(review): 56 / 2 / 1 are presumably the number of knobs in the
            # default knob file that apply to Conv2d, Linear, and all other ops
            # respectively — confirm against the knob file.
            if isinstance(op, Conv2d):
                expected = 56
            elif isinstance(op, Linear):
                expected = 2
            else:
                expected = 1
            # subTest reports every mismatching op rather than stopping at the first.
            with self.subTest(op=op_name):
                self.assertEqual(n_knobs[op_name], expected)

    def test_tuning(self):
        """Smoke test: build a tuner and run it; passes if nothing raises."""
        app = self.get_app("test", batch_size=4)
        tuner = app.get_tuner()
        # NOTE(review): the meaning of (10, 3.0) is defined by the tuner API —
        # presumably an iteration budget and a QoS-loss threshold; confirm.
        tuner.tune(10, 3.0)


if __name__ == "__main__":
    unittest.main()