# test_torchapp.py — unit tests for predtuner.TorchApp
  • import unittest
    
    Yifan Zhao's avatar
    Yifan Zhao committed
    import torch
    from model_zoo import CIFAR, VGG16Cifar10
    from predtuner import TorchApp, accuracy, config_pylogger, get_knobs_from_file
    
    Yifan Zhao's avatar
    Yifan Zhao committed
    from torch.nn import Conv2d, Linear
    
    from torch.utils.data.dataloader import DataLoader
    
    Yifan Zhao's avatar
    Yifan Zhao committed
    from torch.utils.data.dataset import Subset
    
    msg_logger = config_pylogger(output_dir="/tmp", verbose=True)
    
    class TestTorchApp(unittest.TestCase):
        """Tests for predtuner.TorchApp on a pretrained VGG16 and a CIFAR-10 subset.

        Fixtures are read from ``model_data/`` via relative paths, so the tests
        must be run from the directory that contains it.
        """

        def setUp(self):
            # Only the first 100 samples are used to keep each test cheap.
            dataset = CIFAR.from_file(
                "model_data/cifar10/input.bin", "model_data/cifar10/labels.bin"
            )
            self.dataset = Subset(dataset, range(100))

            self.module = VGG16Cifar10()
            # load_state_dict copies the saved values into the module's existing
            # parameters, so mapping the checkpoint to CPU does not change the
            # result; it only avoids a failure when the checkpoint was saved from
            # CUDA tensors and no GPU is available.
            self.module.load_state_dict(
                torch.load("model_data/vgg16_cifar10.pth.tar", map_location="cpu")
            )

        def get_app(self):
            """Build a TorchApp over ``self.module`` and the 100-sample subset.

            The same subset backs both data loaders. With batch_size=500 and only
            100 samples, each loader yields a single batch.
            """
            return TorchApp(
                "TestTorchApp",
                self.module,
                DataLoader(self.dataset, batch_size=500),
                DataLoader(self.dataset, batch_size=500),
                get_knobs_from_file(),
                accuracy,
            )

        def test_init(self):
            """TorchApp assigns knobs to 34 ops, with consistent counts per op type."""
            app = self.get_app()
            n_knobs = {op: len(ks) for op, ks in app.op_knobs.items()}
            self.assertEqual(len(n_knobs), 34)

            # TODO: the hard-coded knob counts expected for Conv2d and Linear ops
            # are missing from this copy of the test; restore them. Until then,
            # check that all ops of each of those types share a single knob count
            # (presumably > 1, since these types were singled out from the
            # one-knob default), and that every other op has exactly one knob.
            counts_by_type = {Conv2d: set(), Linear: set()}
            for op_name, op in app.midx.name_to_module.items():
                if isinstance(op, Conv2d):
                    counts_by_type[Conv2d].add(n_knobs[op_name])
                elif isinstance(op, Linear):
                    counts_by_type[Linear].add(n_knobs[op_name])
                else:
                    self.assertEqual(n_knobs[op_name], 1, msg=op_name)
            for op_type, counts in counts_by_type.items():
                with self.subTest(op_type=op_type.__name__):
                    self.assertLessEqual(len(counts), 1, msg=sorted(counts))
                    for count in counts:
                        self.assertGreater(count, 1)

        def test_baseline_qos(self):
            """With an empty knob config, the measured QoS on the subset is 88.0."""
            app = self.get_app()
            qos, _ = app.measure_qos_perf({}, False)
            self.assertAlmostEqual(qos, 88.0)

        def test_tuning(self):
            """Every config the tuner keeps beats the baseline QoS minus 3.0."""
            app = self.get_app()
            baseline, _ = app.measure_qos_perf({}, False)
            qos_threshold = baseline - 3.0

            tuner = app.get_tuner()
            tuner.tune(100, qos_threshold)

            for conf in tuner.kept_configs:
                # assertGreater reports both operands on failure, unlike assertTrue.
                self.assertGreater(conf.qos, qos_threshold)