Skip to content
Snippets Groups Projects
Commit f9557aac authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Implemented relative threshold

parent 7e4a290a
No related branches found
No related tags found
No related merge requests found
......@@ -93,7 +93,7 @@ class ApproxTuner(Generic[T]):
max_iter: int,
qos_tuner_threshold: float,
qos_keep_threshold: Optional[float] = None,
accuracy_convention: str = "absolute", # TODO: this
is_threshold_relative: bool = False,
take_best_n: Optional[int] = None,
calibrate: bool = True
# TODO: more parameters + opentuner param forwarding
......@@ -106,6 +106,10 @@ class ApproxTuner(Generic[T]):
# By default, keep_threshold == tuner_threshold
opentuner_args = opentuner_default_args()
qos_keep_threshold = qos_keep_threshold or qos_tuner_threshold
if is_threshold_relative:
baseline_qos, _ = self.app.measure_qos_perf({}, False)
qos_tuner_threshold = baseline_qos - qos_tuner_threshold
qos_keep_threshold = baseline_qos - qos_keep_threshold
opentuner_args.test_limit = max_iter
tuner = TunerInterface(
opentuner_args, self.app, qos_tuner_threshold, qos_keep_threshold, max_iter,
......
......@@ -10,23 +10,26 @@ from torch.utils.data.dataset import Subset
msg_logger = config_pylogger(output_dir="/tmp", verbose=True)
class TestTorchAppInit(unittest.TestCase):
def setUp(self):
class TorchAppSetUp(unittest.TestCase):
@classmethod
def setUpClass(cls):
dataset = CIFAR.from_file(
"model_data/cifar10/input.bin", "model_data/cifar10/labels.bin"
)
self.dataset = Subset(dataset, range(100))
self.module = VGG16Cifar10()
self.module.load_state_dict(torch.load("model_data/vgg16_cifar10.pth.tar"))
self.app = TorchApp(
cls.dataset = Subset(dataset, range(100))
cls.module = VGG16Cifar10()
cls.module.load_state_dict(torch.load("model_data/vgg16_cifar10.pth.tar"))
cls.app = TorchApp(
"TestTorchApp",
self.module,
DataLoader(self.dataset, batch_size=500),
DataLoader(self.dataset, batch_size=500),
cls.module,
DataLoader(cls.dataset, batch_size=500),
DataLoader(cls.dataset, batch_size=500),
get_knobs_from_file(),
accuracy,
)
class TestTorchAppTuning(TorchAppSetUp):
def test_knobs(self):
n_knobs = {op: len(ks) for op, ks in self.app.op_knobs.items()}
self.assertEqual(len(n_knobs), 34)
......@@ -43,15 +46,25 @@ class TestTorchAppInit(unittest.TestCase):
qos, _ = self.app.measure_qos_perf({}, False)
self.assertAlmostEqual(qos, 88.0)
def test_tuning_relative_thres(self):
baseline, _ = self.app.measure_qos_perf({}, False)
tuner = self.app.get_tuner()
tuner.tune(100, 3.0, 3.0, True, 10)
for conf in tuner.kept_configs:
self.assertTrue(conf.qos > baseline - 3.0)
if len(tuner.kept_configs) >= 10:
self.assertEqual(len(tuner.best_configs), 10)
class TestTorchAppTuner(TestTorchAppInit):
def setUp(self):
super().setUp()
self.baseline, _ = self.app.measure_qos_perf({}, False)
self.tuner = self.app.get_tuner()
self.tuner.tune(100, self.baseline - 3.0)
class TestTorchAppTunerResult(TorchAppSetUp):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.baseline, _ = cls.app.measure_qos_perf({}, False)
cls.tuner = cls.app.get_tuner()
cls.tuner.tune(100, cls.baseline - 3.0)
def test_tuning(self):
def test_results_qos(self):
configs = self.tuner.kept_configs
for conf in configs:
self.assertTrue(conf.qos > self.baseline - 3.0)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment