Skip to content
Snippets Groups Projects
test_torchapp.py 2.25 KiB
Newer Older
  • Learn to ignore specific revisions
  • import unittest
    
    Yifan Zhao's avatar
    Yifan Zhao committed
    import torch
    from model_zoo import CIFAR, VGG16Cifar10
    from predtuner import TorchApp, accuracy, config_pylogger, get_knobs_from_file
    
    Yifan Zhao's avatar
    Yifan Zhao committed
    from torch.nn import Conv2d, Linear
    
    from torch.utils.data.dataloader import DataLoader
    
    Yifan Zhao's avatar
    Yifan Zhao committed
    from torch.utils.data.dataset import Subset
    
    # Module-level logger setup: runs once when this test module is imported.
    # NOTE(review): output_dir is hard-coded to "/tmp" -- presumably the place
    # where predtuner writes its log output; confirm against config_pylogger
    # and consider a portable temp directory if these tests must run off POSIX.
    msg_logger = config_pylogger(output_dir="/tmp", verbose=True)
    
    Yifan Zhao's avatar
    Yifan Zhao committed
    class TestTorchAppInit(unittest.TestCase):
    
        def setUp(self):
            """Build the fixture used by every test in this case.

            Reads the CIFAR-10 inputs and labels from ``model_data/cifar10``,
            keeps the first 100 samples, restores the network weights from
            ``model_data/vgg16_cifar10.pth.tar`` and stores the resulting
            ``TorchApp`` on ``self.app``.
            """
            full_set = CIFAR.from_file(
                "model_data/cifar10/input.bin", "model_data/cifar10/labels.bin"
            )
            # Indices 0..99 only.
            self.dataset = Subset(full_set, range(100))

            self.module = VGG16Cifar10()
            state = torch.load("model_data/vgg16_cifar10.pth.tar")
            self.module.load_state_dict(state)

            # batch_size (500) exceeds the 100-sample subset, so each loader
            # yields a single batch. Two distinct loader objects over the same
            # subset are passed (presumably the tuning and test inputs --
            # confirm against TorchApp's signature).
            first_loader = DataLoader(self.dataset, batch_size=500)
            second_loader = DataLoader(self.dataset, batch_size=500)
            knobs = get_knobs_from_file()
            self.app = TorchApp(
                "TestTorchApp",
                self.module,
                first_loader,
                second_loader,
                knobs,
                accuracy,
            )
    
    Yifan Zhao's avatar
    Yifan Zhao committed
        def test_knobs(self):
            n_knobs = {op: len(ks) for op, ks in self.app.op_knobs.items()}
    
    Yifan Zhao's avatar
    Yifan Zhao committed
            self.assertEqual(len(n_knobs), 34)
    
    Yifan Zhao's avatar
    Yifan Zhao committed
            for op_name, op in self.app.midx.name_to_module.items():
    
    Yifan Zhao's avatar
    Yifan Zhao committed
                if isinstance(op, Conv2d):
    
    Yifan Zhao's avatar
    Yifan Zhao committed
                elif isinstance(op, Linear):
    
    Yifan Zhao's avatar
    Yifan Zhao committed
                else:
                    nknob = 1
                self.assertEqual(n_knobs[op_name], nknob)
    
    Yifan Zhao's avatar
    Yifan Zhao committed
        def test_baseline_qos(self):
            # measure_qos_perf is called with an empty dict (presumably "no
            # knobs applied" -- confirm) and returns a 2-tuple; only the first
            # element (QoS) is checked, against the reference value 88.0 with
            # assertAlmostEqual's default tolerance.
            baseline_qos, _perf = self.app.measure_qos_perf({}, False)
            self.assertAlmostEqual(baseline_qos, 88.0)
    
    Yifan Zhao's avatar
    Yifan Zhao committed
    
    class TestTorchAppTuner(TestTorchAppInit):
        """Run the tuner before each test and check the configs it reports.

        Reuses the fixture built by ``TestTorchAppInit.setUp``; the tests
        defined on that base class are inherited and run here as well.
        """

        # Allowed QoS drop below the measured baseline. Kept in one place so
        # the threshold handed to the tuner and the one asserted in
        # ``test_tuning`` cannot drift apart.
        QOS_DROP = 3.0

        def setUp(self):
            """Measure the baseline QoS, then tune against ``baseline - QOS_DROP``."""
            super().setUp()
            self.baseline, _ = self.app.measure_qos_perf({}, False)
            self.qos_threshold = self.baseline - self.QOS_DROP
            self.tuner = self.app.get_tuner()
            # First argument is presumably the iteration budget -- confirm
            # against the tuner's ``tune`` signature.
            self.tuner.tune(100, self.qos_threshold)

        def test_tuning(self):
            # Every kept config must have a QoS strictly above the threshold
            # the tuner was given. assertGreater reports both values on
            # failure, unlike assertTrue(a > b).
            for conf in self.tuner.kept_configs:
                self.assertGreater(conf.qos, self.qos_threshold)

        def test_pareto(self):
            # No best config may be strictly beaten on both qos and perf by
            # another best config.
            configs = self.tuner.best_configs
            for c1 in configs:
                dominated = any(
                    c2.qos > c1.qos and c2.perf > c1.perf for c2 in configs
                )
                self.assertFalse(
                    dominated,
                    f"config with qos={c1.qos}, perf={c1.perf} is dominated",
                )

        def test_dummy_calib(self):
            # Each best config's calib_qos must agree with its qos, within
            # assertAlmostEqual's default tolerance.
            for c in self.tuner.best_configs:
                self.assertAlmostEqual(c.calib_qos, c.qos)