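"""Tests for TorchApp on VGG16/CIFAR-10: knob enumeration, baseline QoS,
empirical tuning, model-based tuning, and saving/loading of the QoS models."""
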
import unittest
import torch
from predtuner.model_zoo import CIFAR, VGG16Cifar10
from predtuner import TorchApp, accuracy, config_pylogger, get_knobs_from_file
from torch.nn import Conv2d, Linear
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Subset
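
# Write predtuner log output to /tmp with verbose logging enabled.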
msg_logger = config_pylogger(output_dir="/tmp", verbose=True)


class TorchAppSetUp(unittest.TestCase):
    """Shared fixture: a 100-image CIFAR-10 subset and a pretrained VGG16,
    wrapped in a TorchApp that all test cases below reuse."""

    @classmethod
    def setUpClass(cls):
        dataset = CIFAR.from_file(
            "model_params/vgg16_cifar10/tune_input.bin",
            "model_params/vgg16_cifar10/tune_labels.bin",
        )
        cls.dataset = Subset(dataset, range(100))
        cls.module = VGG16Cifar10()
        cls.module.load_state_dict(torch.load("model_params/vgg16_cifar10.pth.tar"))
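        # The same 100-image loader serves as both tune set and test set,
        # so tuned QoS and test QoS should coincide (see test_dummy_testset).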
        cls.app = TorchApp(
            "TestTorchApp",
            cls.module,
            DataLoader(cls.dataset, batch_size=500),
            DataLoader(cls.dataset, batch_size=500),
            get_knobs_from_file(),
            accuracy,
        )


class TestTorchAppTuning(TorchAppSetUp):
    """Checks knob discovery, the baseline configuration, and a short empirical tuning run."""

    def test_knobs(self):
        n_knobs = {op: len(ks) for op, ks in self.app.op_knobs.items()}
        self.assertEqual(len(n_knobs), 34)
        # Per-op knob count depends on layer type: 29 for Conv2d, 2 for Linear, 1 otherwise.
        for op_name, op in self.app.midx.name_to_module.items():
            if isinstance(op, Conv2d):
                nknob = 29
            elif isinstance(op, Linear):
                nknob = 2
            else:
                nknob = 1
            self.assertEqual(n_knobs[op_name], nknob)

    def test_baseline_knob(self):
        self.assertEqual(self.app.baseline_knob.name, "11")
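
    # With an empty configuration the unmodified network runs; its accuracy on
    # the 100-image subset is 93%.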
    def test_baseline_qos(self):
        qos, _ = self.app.measure_qos_perf({}, False)
        self.assertAlmostEqual(qos, 93.0)
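
    # Tune for 100 iterations, keeping configurations within 3.0 of the
    # baseline QoS and returning up to 10 best ones.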
    def test_tuning_relative_thres(self):
        baseline, _ = self.app.measure_qos_perf({}, False)
        tuner = self.app.get_tuner()
        tuner.tune(100, 3.0, 3.0, True, 10)
        for conf in tuner.kept_configs:
            self.assertTrue(conf.qos > baseline - 3.0)
        if len(tuner.kept_configs) >= 10:
            self.assertEqual(len(tuner.best_configs), 10)
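
    # The app exposes one performance model and two QoS prediction models.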
    def test_enum_models(self):
        self.assertSetEqual(
            set(model.name for model in self.app.get_models()),
            {"perf_linear", "qos_p1", "qos_p2"},
        )


class TestTorchAppTunerResult(TorchAppSetUp):
    """Checks properties of the configurations returned by an empirical tuning run."""

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.baseline, _ = cls.app.measure_qos_perf({}, False)
        cls.tuner = cls.app.get_tuner()
        cls.tuner.tune(100, cls.baseline - 3.0)
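
    # Every kept configuration must stay within 3.0 of the baseline QoS.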
    def test_results_qos(self):
        configs = self.tuner.kept_configs
        for conf in configs:
            self.assertTrue(conf.qos > self.baseline - 3.0)
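
    # best_configs should form a Pareto frontier: no config is beaten by
    # another in both QoS and performance.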
    def test_pareto(self):
        configs = self.tuner.best_configs
        for c1 in configs:
            self.assertFalse(
                any(c2.qos > c1.qos and c2.perf > c1.perf for c2 in configs)
            )
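
    # The test set is the same loader as the tune set, so test QoS should
    # match the tuning-time QoS.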
    def test_dummy_testset(self):
        configs = self.tuner.best_configs
        for c in configs:
            self.assertAlmostEqual(c.test_qos, c.qos)


class TestModeledTuning(TorchAppSetUp):
    """Runs tuning against predicted QoS and performance instead of empirical measurement."""

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.baseline, _ = cls.app.measure_qos_perf({}, False)
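
    # Both tests below tune with the "perf_linear" performance model and one
    # of the two QoS models ("qos_p1" / "qos_p2").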
    def test_qos_p1(self):
        tuner = self.app.get_tuner()
        tuner.tune(
            100,
            3.0,
            is_threshold_relative=True,
            perf_model="perf_linear",
            qos_model="qos_p1",
        )

    def test_qos_p2(self):
        tuner = self.app.get_tuner()
        tuner.tune(
            100,
            3.0,
            is_threshold_relative=True,
            perf_model="perf_linear",
            qos_model="qos_p2",
        )


class TestModelSaving(TorchAppSetUp):
    """Initializes the QoS models with one app instance and reuses the stored
    models from model_storage_folder with a fresh instance."""

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.baseline, _ = cls.app.measure_qos_perf({}, False)
        cls.model_path = "/tmp/test_models"
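        # Initialize the qos_p1/qos_p2 models with one app instance, then
        # create a fresh app that can reuse them from model_path.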
        app = cls.get_app()
        app.init_model("qos_p1")
        app.init_model("qos_p2")
        cls.app = cls.get_app()

    @classmethod
    def get_app(cls):
        return TorchApp(
            "TestTorchApp",
            cls.module,
            DataLoader(cls.dataset, batch_size=500),
            DataLoader(cls.dataset, batch_size=500),
            get_knobs_from_file(),
            accuracy,
            model_storage_folder=cls.model_path,
        )
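
    # The two runs below exercise loading the models initialized in setUpClass.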
    def test_loading_p1(self):
        self.app.get_tuner().tune(
            100,
            3.0,
            is_threshold_relative=True,
            perf_model="perf_linear",
            qos_model="qos_p1",
        )

    def test_loading_p2(self):
        self.app.get_tuner().tune(
            100,
            3.0,
            is_threshold_relative=True,
            perf_model="perf_linear",
            qos_model="qos_p2",
        )