Skip to content
Snippets Groups Projects
test_torchapp.py 1.89 KiB
Newer Older
  • Learn to ignore specific revisions
  • import unittest
    
    
    from torch.utils.data.dataset import Subset
    
    
    from predtuner.approxes import get_knobs_from_file
    from predtuner.torchapp import TorchApp
    from predtuner.torchutil import accuracy
    
    Yifan Zhao's avatar
    Yifan Zhao committed
    from torch.nn import Conv2d, Linear
    
    from torch.utils.data.dataloader import DataLoader
    from torchvision import transforms
    from torchvision.datasets import CIFAR10
    from torchvision.models.vgg import vgg16
    
    
    
    class TestTorchApp(unittest.TestCase):
        """Tests for ``TorchApp`` on a pretrained VGG16 and a 100-image CIFAR-10 subset."""

        def setUp(self):
            """Build the CIFAR-10 subset and load a pretrained VGG16 for each test.

            CIFAR-10 is downloaded to ``/tmp/cifar10`` when not already present.
            """
            # ImageNet channel mean/std, i.e. the normalization the pretrained
            # VGG16 weights were trained with.
            normalize = transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            )
            transform = transforms.Compose([transforms.ToTensor(), normalize])

            dataset = CIFAR10("/tmp/cifar10", download=True, transform=transform)
            # Only the first 100 samples, to keep the tests short.
            self.dataset = Subset(dataset, range(100))

            self.module = vgg16(pretrained=True)

        def get_app(self, name="TestTorchApp", batch_size=1):
            """Build a ``TorchApp`` around ``self.module`` and ``self.dataset``.

            Args:
                name: Application name passed as the first ``TorchApp`` argument.
                batch_size: Batch size for both data loaders. The default of 1 is
                    ``DataLoader``'s own default, so ``get_app()`` behaves as before.

            Returns:
                A ``TorchApp`` built with the knobs from ``get_knobs_from_file()``
                and ``accuracy``; the same subset backs both loader arguments.
            """
            return TorchApp(
                name,
                self.module,
                DataLoader(self.dataset, batch_size=batch_size),
                DataLoader(self.dataset, batch_size=batch_size),
                get_knobs_from_file(),
                accuracy,
            )

        def test_init(self):
            """Check the number of knobs ``TorchApp`` assigns to each layer.

            Layers that are neither ``Conv2d`` nor ``Linear`` must get exactly one
            knob. All ``Conv2d`` layers must share a single knob count (likewise
            all ``Linear`` layers), and that count must exceed one.
            """
            app = self.get_app()
            n_knobs = {op: len(ks) for op, ks in app.op_knobs.items()}
            conv_counts, linear_counts = set(), set()
            for op_name, op in app.midx.name_to_module.items():
                nknob = n_knobs[op_name]
                if isinstance(op, Conv2d):
                    conv_counts.add(nknob)
                elif isinstance(op, Linear):
                    linear_counts.add(nknob)
                else:
                    self.assertEqual(nknob, 1)
            # NOTE(review): the exact expected Conv2d/Linear knob counts are not
            # recorded in this copy of the test (those branches were empty).
            # TODO: restore the exact numbers; until then, require one consistent
            # count per layer type that is larger than the single baseline knob.
            for layer, counts in (("Conv2d", conv_counts), ("Linear", linear_counts)):
                self.assertLessEqual(len(counts), 1, f"{layer}: {sorted(counts)}")
                for count in counts:
                    self.assertGreater(count, 1, layer)

        # def test_baseline_qos(self):
        #     app = self.get_app()
        #     qos, _ = app.measure_qos_perf({}, False)

        def test_tuning(self):
            """Smoke test: run a short tuning session with batches of 4 samples."""
            app = self.get_app("test", batch_size=4)
            tuner = app.get_tuner()
            # Arguments are passed positionally, exactly as in the original test;
            # see the tuner's ``tune`` signature for their meaning.
            tuner.tune(10, 3.0)