Skip to content
Snippets Groups Projects
test_torchapp.py 1.18 KiB
Newer Older
  • Learn to ignore specific revisions
  • import unittest
    
    from predtuner.approxes import get_knobs_from_file
    from predtuner.torchapp import TorchApp
    from predtuner.torchutil import accuracy
    
    Yifan Zhao's avatar
    Yifan Zhao committed
    from torch.nn import Conv2d, Linear
    
    from torch.utils.data.dataloader import DataLoader
    from torchvision import transforms
    from torchvision.datasets import CIFAR10
    from torchvision.models.vgg import vgg16
    
    
    class TestTorchAppInit(unittest.TestCase):
        """Tests for constructing a ``TorchApp`` around VGG16 and CIFAR-10."""

        # Number of knobs this test expects ``TorchApp`` to attach to each op,
        # keyed by layer type. Types are checked in order with ``isinstance``.
        EXPECTED_KNOBS_BY_TYPE = ((Conv2d, 63), (Linear, 9))
        # Ops of any other type are expected to get exactly one knob
        # (NOTE(review): presumably only the exact/baseline knob -- confirm
        # against the default knob file).
        EXPECTED_KNOBS_DEFAULT = 1

        def setUp(self):
            """Load CIFAR-10 (downloaded to /tmp/cifar10 if absent) and a pretrained VGG16."""
            transform = transforms.Compose([transforms.ToTensor()])
            self.dataset = CIFAR10("/tmp/cifar10", download=True, transform=transform)
            self.module = vgg16(pretrained=True)

        def _expected_knobs(self, op):
            """Return how many knobs this test expects ``TorchApp`` to give module ``op``."""
            for op_type, count in self.EXPECTED_KNOBS_BY_TYPE:
                if isinstance(op, op_type):
                    return count
            return self.EXPECTED_KNOBS_DEFAULT

        def test_init(self):
            """Every module indexed by the app gets the expected number of knobs."""
            app = TorchApp(
                "test",
                self.module,
                # NOTE(review): the two loaders presumably are the tuning and
                # test sets; both use the same dataset here -- confirm order.
                DataLoader(self.dataset),
                DataLoader(self.dataset),
                get_knobs_from_file(),
                accuracy,
            )
            n_knobs = {op: len(ks) for op, ks in app.op_knobs.items()}
            for op_name, op in app.midx.name_to_module.items():
                # subTest reports every mismatching op instead of stopping at the first one.
                with self.subTest(op=op_name):
                    # A missing op is reported as a test failure rather than a KeyError.
                    self.assertIn(op_name, n_knobs)
                    self.assertEqual(n_knobs[op_name], self._expected_knobs(op))