Skip to content
Snippets Groups Projects
test_torchapp.py 1.81 KiB
Newer Older
  • Learn to ignore specific revisions
  • import unittest
    
    Yifan Zhao's avatar
    Yifan Zhao committed
    import torch
    
    from torch.utils.data.dataset import Subset
    
    
    from predtuner.approxes import get_knobs_from_file
    from predtuner.torchapp import TorchApp
    from predtuner.torchutil import accuracy
    
    Yifan Zhao's avatar
    Yifan Zhao committed
    from torch.nn import Conv2d, Linear
    
    from torch.utils.data.dataloader import DataLoader
    
    Yifan Zhao's avatar
    Yifan Zhao committed
    from model_zoo import VGG16Cifar10, CIFAR
    
    class TestTorchApp(unittest.TestCase):
    
        def setUp(self):
            """Build the dataset and model fixture used by the tests.

            Loads the dataset from the two files under ``model_data/cifar10``,
            keeps only indices 0..99 of it as ``self.dataset``, then creates a
            ``VGG16Cifar10`` module and loads its weights from
            ``model_data/vgg16_cifar10.pth.tar`` into ``self.module``.

            All paths are relative, so they resolve against the current working
            directory of the test run.
            """
            dataset = CIFAR.from_file(
                "model_data/cifar10/input.bin", "model_data/cifar10/labels.bin"
            )
            # Restrict to the first 100 samples.
            self.dataset = Subset(dataset, range(100))
            self.module = VGG16Cifar10()
            self.module.load_state_dict(torch.load("model_data/vgg16_cifar10.pth.tar"))
    
        def get_app(self):
            """Return a new ``TorchApp`` named "TestTorchApp" built on the fixture.

            Two independent ``DataLoader`` objects with default settings are
            created over ``self.dataset`` and passed as the third and fourth
            arguments; the knobs come from ``get_knobs_from_file()`` and
            ``accuracy`` is passed as the last argument.
            """
            first_loader = DataLoader(self.dataset)
            second_loader = DataLoader(self.dataset)
            knobs = get_knobs_from_file()
            app = TorchApp(
                "TestTorchApp",
                self.module,
                first_loader,
                second_loader,
                knobs,
                accuracy,
            )
            return app
    
    
        def test_init(self):
            app = self.get_app()
    
    Yifan Zhao's avatar
    Yifan Zhao committed
            n_knobs = {op: len(ks) for op, ks in app.op_knobs.items()}
    
    Yifan Zhao's avatar
    Yifan Zhao committed
            self.assertEqual(len(n_knobs), 34)
    
    Yifan Zhao's avatar
    Yifan Zhao committed
            for op_name, op in app.midx.name_to_module.items():
                if isinstance(op, Conv2d):
    
    Yifan Zhao's avatar
    Yifan Zhao committed
                elif isinstance(op, Linear):
    
    Yifan Zhao's avatar
    Yifan Zhao committed
                else:
                    nknob = 1
                self.assertEqual(n_knobs[op_name], nknob)
    
    Yifan Zhao's avatar
    Yifan Zhao committed
        def test_baseline_qos(self):
            """Check that the QoS measured with an empty config is about 0.88."""
            empty_config = {}
            # measure_qos_perf returns a pair; only the first element is checked.
            # NOTE(review): the meaning of the ``False`` flag is not visible here.
            baseline_qos, _ignored = self.get_app().measure_qos_perf(
                empty_config, False
            )
            self.assertAlmostEqual(baseline_qos, 0.88)
    
    
        def test_tuning(self):
            """Smoke-test the tuner on an app whose loaders use batch size 4.

            The test makes no assertions; it passes as long as building the app,
            obtaining its tuner and calling ``tune`` do not raise.
            """
            first_loader = DataLoader(self.dataset, batch_size=4)
            second_loader = DataLoader(self.dataset, batch_size=4)
            knobs = get_knobs_from_file()
            app = TorchApp(
                "test", self.module, first_loader, second_loader, knobs, accuracy
            )
            # NOTE(review): the meaning of the two tune() arguments is not
            # visible in this file -- confirm against the tuner API.
            app.get_tuner().tune(10, 3.0)