"""Tests for ``predtuner.torchapp.TorchApp`` on a VGG16 / CIFAR-10 setup.

Fixtures are read from ``model_data/`` via relative paths, so the tests must
be run from the directory that contains ``model_data``.
"""
import unittest

import torch
from torch.nn import Conv2d, Linear
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Subset

from predtuner.approxes import get_knobs_from_file
from predtuner.torchapp import TorchApp
from predtuner.torchutil import accuracy

from model_zoo import VGG16Cifar10, CIFAR


class TestTorchApp(unittest.TestCase):
    """Check knob discovery, baseline QoS and tuning of a ``TorchApp``."""

    def setUp(self):
        """Load the first 100 CIFAR-10 samples and the pretrained VGG16 weights."""
        dataset = CIFAR.from_file(
            "model_data/cifar10/input.bin", "model_data/cifar10/labels.bin"
        )
        # Only 100 samples keep the tests reasonably fast.
        self.dataset = Subset(dataset, range(100))
        self.module = VGG16Cifar10()
        # map_location="cpu" lets the checkpoint load on hosts without CUDA even
        # if it was saved with GPU tensors; load_state_dict copies the values
        # into the module's own parameters either way.
        self.module.load_state_dict(
            torch.load("model_data/vgg16_cifar10.pth.tar", map_location="cpu")
        )

    def get_app(self, name="TestTorchApp", batch_size=1):
        """Build a ``TorchApp`` over ``self.module`` and ``self.dataset``.

        The same 100-sample subset is wrapped in a DataLoader and passed for
        both of ``TorchApp``'s DataLoader arguments.

        :param name: application name given to ``TorchApp``.
        :param batch_size: batch size for both DataLoaders. The default of 1
            matches ``DataLoader``'s own default.
        :return: the constructed ``TorchApp``.
        """
        return TorchApp(
            name,
            self.module,
            DataLoader(self.dataset, batch_size=batch_size),
            DataLoader(self.dataset, batch_size=batch_size),
            get_knobs_from_file(),
            accuracy,
        )

    def test_init(self):
        """Check the number of ops with knobs and the knob count per module type."""
        app = self.get_app()
        n_knobs = {op: len(ks) for op, ks in app.op_knobs.items()}
        # Expected number of ops exposing knobs for VGG16Cifar10.
        self.assertEqual(len(n_knobs), 34)
        for op_name, op in app.midx.name_to_module.items():
            if isinstance(op, Conv2d):
                nknob = 56
            elif isinstance(op, Linear):
                nknob = 2
            else:
                nknob = 1
            # subTest reports every mismatching op instead of stopping at the first.
            with self.subTest(op=op_name):
                self.assertEqual(n_knobs[op_name], nknob)

    def test_baseline_qos(self):
        """An empty knob configuration should give 0.88 QoS on the 100-sample subset."""
        app = self.get_app()
        qos, _ = app.measure_qos_perf({}, False)
        self.assertAlmostEqual(qos, 0.88)

    def test_tuning(self):
        """Smoke test: tuning with batch size 4 runs without raising.

        NOTE(review): no assertion is made on the tuning result.
        """
        app = self.get_app("test", batch_size=4)
        tuner = app.get_tuner()
        tuner.tune(10, 3.0)


if __name__ == "__main__":
    unittest.main()