Skip to content
Snippets Groups Projects
Commit f9557aac authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Implemented relative threshold

parent 7e4a290a
No related branches found
No related tags found
No related merge requests found
...@@ -93,7 +93,7 @@ class ApproxTuner(Generic[T]): ...@@ -93,7 +93,7 @@ class ApproxTuner(Generic[T]):
max_iter: int, max_iter: int,
qos_tuner_threshold: float, qos_tuner_threshold: float,
qos_keep_threshold: Optional[float] = None, qos_keep_threshold: Optional[float] = None,
accuracy_convention: str = "absolute", # TODO: this is_threshold_relative: bool = False,
take_best_n: Optional[int] = None, take_best_n: Optional[int] = None,
calibrate: bool = True calibrate: bool = True
# TODO: more parameters + opentuner param forwarding # TODO: more parameters + opentuner param forwarding
...@@ -106,6 +106,10 @@ class ApproxTuner(Generic[T]): ...@@ -106,6 +106,10 @@ class ApproxTuner(Generic[T]):
# By default, keep_threshold == tuner_threshold # By default, keep_threshold == tuner_threshold
opentuner_args = opentuner_default_args() opentuner_args = opentuner_default_args()
qos_keep_threshold = qos_keep_threshold or qos_tuner_threshold qos_keep_threshold = qos_keep_threshold or qos_tuner_threshold
if is_threshold_relative:
baseline_qos, _ = self.app.measure_qos_perf({}, False)
qos_tuner_threshold = baseline_qos - qos_tuner_threshold
qos_keep_threshold = baseline_qos - qos_keep_threshold
opentuner_args.test_limit = max_iter opentuner_args.test_limit = max_iter
tuner = TunerInterface( tuner = TunerInterface(
opentuner_args, self.app, qos_tuner_threshold, qos_keep_threshold, max_iter, opentuner_args, self.app, qos_tuner_threshold, qos_keep_threshold, max_iter,
......
...@@ -10,23 +10,26 @@ from torch.utils.data.dataset import Subset ...@@ -10,23 +10,26 @@ from torch.utils.data.dataset import Subset
msg_logger = config_pylogger(output_dir="/tmp", verbose=True) msg_logger = config_pylogger(output_dir="/tmp", verbose=True)
class TestTorchAppInit(unittest.TestCase): class TorchAppSetUp(unittest.TestCase):
def setUp(self): @classmethod
def setUpClass(cls):
dataset = CIFAR.from_file( dataset = CIFAR.from_file(
"model_data/cifar10/input.bin", "model_data/cifar10/labels.bin" "model_data/cifar10/input.bin", "model_data/cifar10/labels.bin"
) )
self.dataset = Subset(dataset, range(100)) cls.dataset = Subset(dataset, range(100))
self.module = VGG16Cifar10() cls.module = VGG16Cifar10()
self.module.load_state_dict(torch.load("model_data/vgg16_cifar10.pth.tar")) cls.module.load_state_dict(torch.load("model_data/vgg16_cifar10.pth.tar"))
self.app = TorchApp( cls.app = TorchApp(
"TestTorchApp", "TestTorchApp",
self.module, cls.module,
DataLoader(self.dataset, batch_size=500), DataLoader(cls.dataset, batch_size=500),
DataLoader(self.dataset, batch_size=500), DataLoader(cls.dataset, batch_size=500),
get_knobs_from_file(), get_knobs_from_file(),
accuracy, accuracy,
) )
class TestTorchAppTuning(TorchAppSetUp):
def test_knobs(self): def test_knobs(self):
n_knobs = {op: len(ks) for op, ks in self.app.op_knobs.items()} n_knobs = {op: len(ks) for op, ks in self.app.op_knobs.items()}
self.assertEqual(len(n_knobs), 34) self.assertEqual(len(n_knobs), 34)
...@@ -43,15 +46,25 @@ class TestTorchAppInit(unittest.TestCase): ...@@ -43,15 +46,25 @@ class TestTorchAppInit(unittest.TestCase):
qos, _ = self.app.measure_qos_perf({}, False) qos, _ = self.app.measure_qos_perf({}, False)
self.assertAlmostEqual(qos, 88.0) self.assertAlmostEqual(qos, 88.0)
def test_tuning_relative_thres(self):
baseline, _ = self.app.measure_qos_perf({}, False)
tuner = self.app.get_tuner()
tuner.tune(100, 3.0, 3.0, True, 10)
for conf in tuner.kept_configs:
self.assertTrue(conf.qos > baseline - 3.0)
if len(tuner.kept_configs) >= 10:
self.assertEqual(len(tuner.best_configs), 10)
class TestTorchAppTuner(TestTorchAppInit): class TestTorchAppTunerResult(TorchAppSetUp):
def setUp(self): @classmethod
super().setUp() def setUpClass(cls):
self.baseline, _ = self.app.measure_qos_perf({}, False) super().setUpClass()
self.tuner = self.app.get_tuner() cls.baseline, _ = cls.app.measure_qos_perf({}, False)
self.tuner.tune(100, self.baseline - 3.0) cls.tuner = cls.app.get_tuner()
cls.tuner.tune(100, cls.baseline - 3.0)
def test_tuning(self): def test_results_qos(self):
configs = self.tuner.kept_configs configs = self.tuner.kept_configs
for conf in configs: for conf in configs:
self.assertTrue(conf.qos > self.baseline - 3.0) self.assertTrue(conf.qos > self.baseline - 3.0)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment