"""Unit tests for constructing a predtuner ``TorchApp``."""
import unittest

from predtuner.approxes import get_knobs_from_file
from predtuner.torchapp import TorchApp
from predtuner.torchutil import accuracy
from torch.nn import Conv2d, Linear
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models.vgg import vgg16
class TestTorchAppInit(unittest.TestCase):
    """Check the per-layer knob counts of a ``TorchApp`` built around VGG16."""

    def setUp(self):
        """Prepare the CIFAR10 dataset and a pretrained VGG16 for each test.

        The dataset is stored under ``/tmp/cifar10`` and is requested with
        ``download=True``; images are converted to tensors only.
        """
        transform = transforms.Compose([transforms.ToTensor()])
        self.dataset = CIFAR10("/tmp/cifar10", download=True, transform=transform)
        self.module = vgg16(pretrained=True)

    def test_init(self):
        """Build a ``TorchApp`` and verify how many knobs each operator gets.

        Every entry of ``app.midx.name_to_module`` is compared against
        ``app.op_knobs``: ``Conv2d`` layers are expected to have 63 knobs,
        ``Linear`` layers 9, and every other module exactly 1.
        """
        app = TorchApp(
            "test",
            self.module,
            # The same dataset is wrapped in two separate loaders, one per
            # loader argument of TorchApp.
            DataLoader(self.dataset),
            DataLoader(self.dataset),
            get_knobs_from_file(),
            accuracy,
        )
        n_knobs = {op: len(ks) for op, ks in app.op_knobs.items()}
        for op_name, op in app.midx.name_to_module.items():
            # NOTE(review): 63 / 9 / 1 presumably reflect the default knob
            # set returned by get_knobs_from_file() -- confirm if that file
            # changes.
            if isinstance(op, Conv2d):
                expected = 63
            elif isinstance(op, Linear):
                expected = 9
            else:
                expected = 1
            # subTest keeps checking the remaining layers after a mismatch
            # and names the offending layer in the failure report.
            with self.subTest(op=op_name):
                self.assertEqual(
                    n_knobs[op_name],
                    expected,
                    f"unexpected knob count for {op_name!r}",
                )