Newer
Older
from torch.utils.data.dataset import Subset
from predtuner.approxes import get_knobs_from_file
from predtuner.torchapp import TorchApp
from predtuner.torchutil import accuracy
from torch.utils.data.dataloader import DataLoader
# NOTE(review): these statements assign to `self`, so they presumably form the
# body of a `setUp(self)` method of a unittest.TestCase subclass whose
# `class ...:` / `def setUp(self):` header lines (and indentation) are missing
# from this file -- restore them.
# NOTE(review): `CIFAR`, `VGG16Cifar10` and `torch` are not imported in the
# visible import block; confirm they are imported elsewhere.
dataset = CIFAR.from_file("model_data/cifar10/input.bin", "model_data/cifar10/labels.bin")
# Keep only indices 0..99 of the dataset (presumably to keep the tests fast).
self.dataset = Subset(dataset, range(100))
self.module = VGG16Cifar10()
# Load saved weights for the model from disk.
self.module.load_state_dict(torch.load("model_data/vgg16_cifar10.pth.tar"))
def get_app(self):
    """Construct the ``TorchApp`` instance shared by the tests.

    Both data loaders iterate ``self.dataset`` with DataLoader's default
    settings.  The knobs come from ``get_knobs_from_file()`` and QoS is
    scored with ``accuracy``.

    Returns:
        A new ``TorchApp`` named ``"TestTorchApp"`` wrapping ``self.module``.
    """
    # Two independent loaders over the same dataset (presumably the tuning
    # and test loaders expected by TorchApp -- confirm against its signature).
    first_loader = DataLoader(self.dataset)
    second_loader = DataLoader(self.dataset)
    knobs = get_knobs_from_file()
    return TorchApp(
        "TestTorchApp",
        self.module,
        first_loader,
        second_loader,
        knobs,
        accuracy,
    )
def test_init(self):
    """Check how many knobs ``TorchApp`` assigns to each module.

    Every module in ``app.midx.name_to_module`` is looked up in
    ``app.op_knobs``:

    * non-``Conv2d`` modules must have exactly one knob;
    * all ``Conv2d`` modules must have the same number of knobs.
    """
    app = self.get_app()
    n_knobs = {op: len(ks) for op, ks in app.op_knobs.items()}
    # NOTE(review): the expected knob count for Conv2d layers was missing here
    # (the `if` branch had no body, which is a syntax error).  Rather than
    # invent a constant, require every Conv2d layer to share the count seen on
    # the first one; pin the exact number once it is known.
    conv_nknob = None
    for op_name, op in app.midx.name_to_module.items():
        if isinstance(op, Conv2d):
            if conv_nknob is None:
                conv_nknob = n_knobs[op_name]
            nknob = conv_nknob
        else:
            nknob = 1
        self.assertEqual(n_knobs[op_name], nknob)
def test_baseline_qos(self):
    """The QoS measured with an empty configuration should be 0.88.

    ``measure_qos_perf`` is called with an empty dict (presumably meaning no
    knobs are applied) and ``False`` as its second positional argument.  It
    must return a 2-tuple; only the first element (QoS) is checked, using
    ``assertAlmostEqual``'s default 7-place tolerance.
    """
    torch_app = self.get_app()
    baseline_qos, _ = torch_app.measure_qos_perf({}, False)
    self.assertAlmostEqual(baseline_qos, 0.88)
def test_tuning(self):
    """Smoke test: build a batched ``TorchApp`` and run its tuner.

    Both loaders use a batch size of 4 over ``self.dataset``.  There are no
    assertions: the test passes as long as ``tune`` completes without
    raising.
    """
    batch_size = 4
    tuning_app = TorchApp(
        "test",
        self.module,
        DataLoader(self.dataset, batch_size=batch_size),
        DataLoader(self.dataset, batch_size=batch_size),
        get_knobs_from_file(),
        accuracy,
    )
    # NOTE(review): the meaning of the positional arguments (10, 3.0) is set
    # by the tuner's API, which is not visible here; consider passing them by
    # keyword for readability.
    tuning_app.get_tuner().tune(10, 3.0)