Skip to content
Snippets Groups Projects
test_torchapp.py 813 B
Newer Older
  • Learn to ignore specific revisions
  • import unittest
    
    from predtuner.approxes import get_knobs_from_file
    from predtuner.torchapp import TorchApp
    from predtuner.torchutil import accuracy
    from torch.utils.data.dataloader import DataLoader
    from torchvision import transforms
    from torchvision.datasets import CIFAR10
    from torchvision.models.vgg import vgg16
    
    
    class TestTorchAppInit(unittest.TestCase):
        """Smoke test: a ``TorchApp`` can be built from a real model and dataset."""

        def setUp(self):
            """Load CIFAR-10 and a VGG-16 model as fixtures for each test.

            ``download=True`` lets torchvision fetch CIFAR-10 into ``/tmp/cifar10``
            when it is not already there, and ``pretrained=True`` loads pretrained
            VGG-16 weights; both may therefore hit the network on first run.
            NOTE(review): ``pretrained=`` is the legacy torchvision spelling (newer
            releases prefer ``weights=``) — confirm against the pinned version.
            """
            transform = transforms.Compose([transforms.ToTensor()])
            self.dataset = CIFAR10("/tmp/cifar10", download=True, transform=transform)
            self.module = vgg16(pretrained=True)

        def test_init(self):
            """Construct a ``TorchApp`` and check the constructor yields one."""
            # The same dataset backs both DataLoader arguments; presumably these
            # are the tuning and test loaders — TODO confirm against TorchApp.
            app = TorchApp(
                "test",
                self.module,
                DataLoader(self.dataset),
                DataLoader(self.dataset),
                get_knobs_from_file(),
                accuracy,
            )
            # Assert on the result so the test checks more than "did not raise"
            # and `app` is not a dead local.
            self.assertIsInstance(app, TorchApp)