import unittest

import torch
from model_zoo import CIFAR, VGG16Cifar10
from predtuner import TorchApp, accuracy, config_pylogger, get_knobs_from_file
from torch.nn import Conv2d, Linear
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Subset

msg_logger = config_pylogger(output_dir="/tmp", verbose=True)


class TestTorchAppInit(unittest.TestCase):
    """Tests for a freshly constructed ``TorchApp`` wrapping VGG16 on CIFAR-10.

    The dataset, model and app are built once per class in ``setUpClass``
    rather than once per test.  Loading weights is costly, and every test
    here only reads from these objects.
    """

    @classmethod
    def setUpClass(cls):
        dataset = CIFAR.from_file(
            "model_data/cifar10/input.bin", "model_data/cifar10/labels.bin"
        )
        # Only the first 100 samples are used, to keep the tests fast.
        cls.dataset = Subset(dataset, range(100))
        cls.module = VGG16Cifar10()
        cls.module.load_state_dict(torch.load("model_data/vgg16_cifar10.pth.tar"))
        # batch_size (500) exceeds the subset size (100), so each loader
        # yields exactly one batch.
        cls.app = TorchApp(
            "TestTorchApp",
            cls.module,
            DataLoader(cls.dataset, batch_size=500),
            DataLoader(cls.dataset, batch_size=500),
            get_knobs_from_file(),
            accuracy,
        )

    def test_knobs(self):
        """Each op exposes the expected number of knobs, by layer type."""
        n_knobs = {op: len(ks) for op, ks in self.app.op_knobs.items()}
        # NOTE(review): 34 is presumably the number of knob-bearing layers in
        # VGG16Cifar10; confirm if the model definition changes.
        self.assertEqual(len(n_knobs), 34)
        for op_name, op in self.app.midx.name_to_module.items():
            # Expected counts are tied to the default knob file loaded by
            # get_knobs_from_file().
            if isinstance(op, Conv2d):
                nknob = 56
            elif isinstance(op, Linear):
                nknob = 2
            else:
                nknob = 1
            # subTest reports every mismatching op, not just the first one.
            with self.subTest(op=op_name):
                self.assertIn(op_name, n_knobs)
                self.assertEqual(n_knobs[op_name], nknob)

    def test_baseline_qos(self):
        """The unapproximated model (empty config) reaches the known accuracy."""
        qos, _ = self.app.measure_qos_perf({}, False)
        self.assertAlmostEqual(qos, 88.0)


class TestTorchAppTuner(TestTorchAppInit):
    """Tests for the tuner's results on top of the ``TorchApp`` fixture.

    A single 100-iteration tuning session runs once per class.  Every test
    inspects the same session, including the tests inherited from
    ``TestTorchAppInit``.
    """

    # Allowed QoS drop below the baseline, in the units of the QoS metric.
    QOS_LOSS = 3.0

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.baseline, _ = cls.app.measure_qos_perf({}, False)
        cls.qos_threshold = cls.baseline - cls.QOS_LOSS
        cls.tuner = cls.app.get_tuner()
        cls.tuner.tune(100, cls.qos_threshold)

    def test_tuning(self):
        """Every kept config stays strictly above the QoS threshold."""
        # NOTE(review): passes vacuously if tuning kept no configs.
        for conf in self.tuner.kept_configs:
            self.assertGreater(conf.qos, self.qos_threshold)

    def test_pareto(self):
        """No best config is strictly dominated in both QoS and perf."""
        configs = self.tuner.best_configs
        for c1 in configs:
            self.assertFalse(
                any(c2.qos > c1.qos and c2.perf > c1.perf for c2 in configs),
                msg="a config in best_configs is dominated by another",
            )

    def test_dummy_calib(self):
        """Calibrated QoS equals measured QoS for every best config."""
        # NOTE(review): presumably holds because no QoS model/calibration is
        # active for this tuner; confirm against predtuner's defaults.
        for c in self.tuner.best_configs:
            self.assertAlmostEqual(c.calib_qos, c.qos)