Newer
Older
import shutil
import tempfile
import unittest

import torch
from model_zoo import CIFAR, VGG16Cifar10
from predtuner import TorchApp, accuracy, config_pylogger, get_knobs_from_file
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Subset
# Configure the package logger once at import time: verbose, output_dir "/tmp".
msg_logger = config_pylogger(verbose=True, output_dir="/tmp")
class TorchAppSetUp(unittest.TestCase):
    """Shared fixture: a VGG16/CIFAR-10 ``TorchApp`` built once per test class.

    This class deliberately defines no tests. unittest re-runs every inherited
    ``test_*`` method in each subclass, so keeping tests here would repeat them
    (including a 100-iteration tuning run) once per fixture subclass.
    """

    @classmethod
    def setUpClass(cls):
        dataset = CIFAR.from_file(
            "model_data/cifar10/input.bin", "model_data/cifar10/labels.bin"
        )
        # Only the first 100 samples are used; with batch_size=500 each
        # DataLoader yields the whole subset as a single batch.
        cls.dataset = Subset(dataset, range(100))
        cls.module = VGG16Cifar10()
        cls.module.load_state_dict(torch.load("model_data/vgg16_cifar10.pth.tar"))
        # Same constructor call (including the app name) as
        # TestModelSaving.get_app, minus the model storage folder.
        cls.app = TorchApp(
            "TestTorchApp",
            cls.module,
            DataLoader(cls.dataset, batch_size=500),
            DataLoader(cls.dataset, batch_size=500),
            get_knobs_from_file(),
            accuracy,
        )


class TestTorchAppTuning(TorchAppSetUp):
    """Knob enumeration, relative-threshold tuning, and model listing."""

    def test_knobs(self):
        """Every indexed op has a knob list; ops of other types have exactly one."""
        n_knobs = {op: len(ks) for op, ks in self.app.op_knobs.items()}
        for op_name, op in self.app.midx.name_to_module.items():
            self.assertIn(op_name, n_knobs)
            if isinstance(op, (torch.nn.Conv2d, torch.nn.Linear)):
                # NOTE(review): the original per-layer-type branches (and their
                # exact knob counts) were lost from this file; only the
                # `else: nknob = 1` fallback survived. Assumes Conv2d/Linear are
                # the approximable layer types -- confirm against the knob file
                # and restore the exact counts with assertEqual.
                self.assertGreater(n_knobs[op_name], 1)
            else:
                self.assertEqual(n_knobs[op_name], 1)

    def test_tuning_relative_thres(self):
        """Kept configs stay within 3.0 QoS of the baseline; at most 10 are best."""
        baseline, _ = self.app.measure_qos_perf({}, False)
        tuner = self.app.get_tuner()
        # NOTE(review): positional args presumably are (max_iter, tuner threshold,
        # keep threshold, is_threshold_relative, take_best_n) -- confirm against
        # the Tuner.tune signature.
        tuner.tune(100, 3.0, 3.0, True, 10)
        for conf in tuner.kept_configs:
            self.assertGreater(conf.qos, baseline - 3.0)
        if len(tuner.kept_configs) >= 10:
            self.assertEqual(len(tuner.best_configs), 10)

    def test_enum_models(self):
        """The app exposes exactly the three built-in predictive models."""
        self.assertSetEqual(
            set(model.name for model in self.app.get_models()),
            {"perf_linear", "qos_p1", "qos_p2"},
        )
class TestTorchAppTunerResult(TorchAppSetUp):
    """Checks on the configurations produced by one tuning run shared by all tests."""

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.baseline, _ = cls.app.measure_qos_perf({}, False)
        cls.tuner = cls.app.get_tuner()
        # Threshold passed as a QoS value (baseline - 3.0); relative mode is not
        # requested here, unlike the other tuning tests in this file.
        cls.tuner.tune(100, cls.baseline - 3.0)

    def test_results_qos(self):
        """Every kept config is above the QoS threshold used for tuning."""
        for conf in self.tuner.kept_configs:
            self.assertGreater(conf.qos, self.baseline - 3.0)

    def test_pareto(self):
        """No best config is strictly beaten on both QoS and perf by another."""
        configs = self.tuner.best_configs
        for c1 in configs:
            self.assertFalse(
                any(c2.qos > c1.qos and c2.perf > c1.perf for c2 in configs)
            )

    def test_dummy_calib(self):
        """Each best config's calibrated QoS matches its QoS."""
        configs = self.tuner.best_configs
        for c in configs:
            self.assertAlmostEqual(c.calib_qos, c.qos)
class TestModeledTuning(TorchAppSetUp):
    """Smoke tests: tuning driven by predictive perf/QoS models runs to completion."""

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.baseline, _ = cls.app.measure_qos_perf({}, False)

    def _tune_with(self, qos_model):
        """Run a 100-iteration, relative-threshold (3.0) tuning using `qos_model`."""
        self.app.get_tuner().tune(
            100,
            3.0,
            is_threshold_relative=True,
            perf_model="perf_linear",
            qos_model=qos_model,
        )

    def test_qos_p1(self):
        self._tune_with("qos_p1")

    def test_qos_p2(self):
        self._tune_with("qos_p2")
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
class TestModelSaving(TorchAppSetUp):
    """QoS models initialized by one app are reused by a second app on the same folder."""

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.baseline, _ = cls.app.measure_qos_perf({}, False)
        # Fresh, per-run storage folder. A fixed path such as /tmp/test_models
        # persists across runs, so the first app below could presumably find
        # stale models there instead of producing its own.
        cls.model_path = tempfile.mkdtemp(prefix="test_models_")
        app = cls.get_app()
        app.init_model("qos_p1")
        app.init_model("qos_p2")
        # A second app on the same folder; intended to pick up what the first saved.
        cls.app = cls.get_app()

    @classmethod
    def tearDownClass(cls):
        # Remove the per-run model folder so test runs do not accumulate files.
        shutil.rmtree(cls.model_path, ignore_errors=True)
        super().tearDownClass()

    @classmethod
    def get_app(cls):
        """Build a ``TorchApp`` whose model storage folder is ``cls.model_path``."""
        return TorchApp(
            "TestTorchApp",
            cls.module,
            DataLoader(cls.dataset, batch_size=500),
            DataLoader(cls.dataset, batch_size=500),
            get_knobs_from_file(),
            accuracy,
            model_storage_folder=cls.model_path,
        )

    def _tune_with(self, qos_model):
        """Run a 100-iteration, relative-threshold (3.0) tuning using `qos_model`."""
        self.app.get_tuner().tune(
            100,
            3.0,
            is_threshold_relative=True,
            perf_model="perf_linear",
            qos_model=qos_model,
        )

    def test_loading_p1(self):
        self._tune_with("qos_p1")

    def test_loading_p2(self):
        self._tune_with("qos_p2")