Newer
Older
import torch
from model_zoo import CIFAR, VGG16Cifar10
from predtuner import TorchApp, accuracy, config_pylogger, get_knobs_from_file
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Subset
# Configure the module-wide logger once, at import time
# (verbose=True, with output_dir pointed at /tmp).
msg_logger = config_pylogger(verbose=True, output_dir="/tmp")
# NOTE(review): these lines use `self` at module level and end with a `)` that
# closes a call opened on a line not present here. They look like the body of a
# `setUp` method whose `def` line (and enclosing class header) and the opening
# of the call taking the DataLoader arguments were lost — SOURCE appears to be
# scraped diff residue. Restore the missing lines before running.
# Load the CIFAR inputs and labels from the two binary files.
dataset = CIFAR.from_file(
"model_data/cifar10/input.bin", "model_data/cifar10/labels.bin"
)
# Keep only indices 0..99 of the dataset (presumably to keep the test fast).
self.dataset = Subset(dataset, range(100))
self.module = VGG16Cifar10()
# Load the saved state dict from the checkpoint file into the model.
self.module.load_state_dict(torch.load("model_data/vgg16_cifar10.pth.tar"))
# Arguments to a call whose opening line is not visible here (presumably the
# TorchApp constructor — confirm). Both loaders wrap the same 100-sample subset;
# batch_size=500 exceeds that size, so each loader yields a single batch.
DataLoader(self.dataset, batch_size=500),
DataLoader(self.dataset, batch_size=500),
get_knobs_from_file(),
accuracy,
)
    def test_knobs(self):
        """Check the number of knobs registered for each named module.

        Compares ``len(ks)`` from ``self.app.op_knobs`` against an expected
        count ``nknob`` for every entry of ``self.app.midx.name_to_module``.

        NOTE(review): the ``else:`` below has no matching ``if``/``elif`` in
        this view. The branches that set ``nknob`` for specific op types appear
        to have been lost (SOURCE looks like scraped diff residue), so this
        method does not parse as written. Restore the missing branches.
        """
        # Op name -> number of knobs the app registered for that op.
        n_knobs = {op: len(ks) for op, ks in self.app.op_knobs.items()}
        for op_name, op in self.app.midx.name_to_module.items():
        else:
            # Fallback expected count for ops not matched by the missing branches.
            nknob = 1
            self.assertEqual(n_knobs[op_name], nknob)
class TestTorchAppTuner(TestTorchAppInit):
    """Tuner tests built on the app that ``TestTorchAppInit.setUp`` creates.

    ``setUp`` measures a baseline QoS, then runs the tuner for ``TUNE_ITERS``
    iterations with a QoS threshold of ``baseline - QOS_LOSS``. The tests check
    the configurations the tuner reports.
    """

    # Number of tuning iterations passed to ``tuner.tune``.
    TUNE_ITERS = 100
    # Allowed QoS drop below the baseline; shared by setUp and test_tuner so
    # the two cannot drift apart.
    QOS_LOSS = 3.0

    def setUp(self):
        super().setUp()
        # Baseline measured with an empty configuration dict (presumably no
        # approximations applied — confirm against measure_qos_perf).
        self.baseline, _ = self.app.measure_qos_perf({}, False)
        self.qos_threshold = self.baseline - self.QOS_LOSS
        self.tuner = self.app.get_tuner()
        self.tuner.tune(self.TUNE_ITERS, self.qos_threshold)

    def test_tuner(self):
        """Every reported configuration must have QoS above the threshold."""
        # NOTE(review): the iterable this loop ran over was missing from the
        # source (the assertion had leaked into setUp with `conf` undefined).
        # `best_configs`, as used by the other tests, is assumed — confirm.
        for conf in self.tuner.best_configs:
            self.assertGreater(conf.qos, self.qos_threshold)

    def test_pareto(self):
        """No config in ``best_configs`` is beaten by another on both qos and perf."""
        configs = self.tuner.best_configs
        for c1 in configs:
            self.assertFalse(
                any(c2.qos > c1.qos and c2.perf > c1.perf for c2 in configs)
            )

    def test_dummy_calib(self):
        """Each best config's ``calib_qos`` equals its ``qos`` (to 7 places)."""
        configs = self.tuner.best_configs
        for c in configs:
            self.assertAlmostEqual(c.calib_qos, c.qos)