diff --git a/predtuner/approxapp.py b/predtuner/approxapp.py index 1ca593f1320117a90f13ed1fa2c6b5bd86591123..a4c276afb6ecd3f86b3a90a8704d3a5247e9ea0c 100644 --- a/predtuner/approxapp.py +++ b/predtuner/approxapp.py @@ -93,7 +93,7 @@ class ApproxTuner(Generic[T]): max_iter: int, qos_tuner_threshold: float, qos_keep_threshold: Optional[float] = None, - accuracy_convention: str = "absolute", # TODO: this + is_threshold_relative: bool = False, take_best_n: Optional[int] = None, calibrate: bool = True # TODO: more parameters + opentuner param forwarding @@ -106,6 +106,10 @@ class ApproxTuner(Generic[T]): # By default, keep_threshold == tuner_threshold opentuner_args = opentuner_default_args() qos_keep_threshold = qos_keep_threshold or qos_tuner_threshold + if is_threshold_relative: + baseline_qos, _ = self.app.measure_qos_perf({}, False) + qos_tuner_threshold = baseline_qos - qos_tuner_threshold + qos_keep_threshold = baseline_qos - qos_keep_threshold opentuner_args.test_limit = max_iter tuner = TunerInterface( opentuner_args, self.app, qos_tuner_threshold, qos_keep_threshold, max_iter, diff --git a/test/test_torchapp.py b/test/test_torchapp.py index 117a6f2e05621aee903af96bc5eb80759037c6b6..d2752d7115872a52ff7df3c29fb1a3e0c57bbebc 100644 --- a/test/test_torchapp.py +++ b/test/test_torchapp.py @@ -10,23 +10,26 @@ from torch.utils.data.dataset import Subset msg_logger = config_pylogger(output_dir="/tmp", verbose=True) -class TestTorchAppInit(unittest.TestCase): - def setUp(self): +class TorchAppSetUp(unittest.TestCase): + @classmethod + def setUpClass(cls): dataset = CIFAR.from_file( "model_data/cifar10/input.bin", "model_data/cifar10/labels.bin" ) - self.dataset = Subset(dataset, range(100)) - self.module = VGG16Cifar10() - self.module.load_state_dict(torch.load("model_data/vgg16_cifar10.pth.tar")) - self.app = TorchApp( + cls.dataset = Subset(dataset, range(100)) + cls.module = VGG16Cifar10() + cls.module.load_state_dict(torch.load("model_data/vgg16_cifar10.pth.tar")) + cls.app = TorchApp( "TestTorchApp", - self.module, - DataLoader(self.dataset, batch_size=500), - DataLoader(self.dataset, batch_size=500), + cls.module, + DataLoader(cls.dataset, batch_size=500), + DataLoader(cls.dataset, batch_size=500), get_knobs_from_file(), accuracy, ) + +class TestTorchAppTuning(TorchAppSetUp): def test_knobs(self): n_knobs = {op: len(ks) for op, ks in self.app.op_knobs.items()} self.assertEqual(len(n_knobs), 34) @@ -43,15 +46,25 @@ class TestTorchAppInit(unittest.TestCase): qos, _ = self.app.measure_qos_perf({}, False) self.assertAlmostEqual(qos, 88.0) + def test_tuning_relative_thres(self): + baseline, _ = self.app.measure_qos_perf({}, False) + tuner = self.app.get_tuner() + tuner.tune(100, 3.0, 3.0, True, 10) + for conf in tuner.kept_configs: + self.assertTrue(conf.qos > baseline - 3.0) + if len(tuner.kept_configs) >= 10: + self.assertEqual(len(tuner.best_configs), 10) + -class TestTorchAppTuner(TestTorchAppInit): - def setUp(self): - super().setUp() - self.baseline, _ = self.app.measure_qos_perf({}, False) - self.tuner = self.app.get_tuner() - self.tuner.tune(100, self.baseline - 3.0) +class TestTorchAppTunerResult(TorchAppSetUp): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.baseline, _ = cls.app.measure_qos_perf({}, False) + cls.tuner = cls.app.get_tuner() + cls.tuner.tune(100, cls.baseline - 3.0) - def test_tuning(self): + def test_results_qos(self): configs = self.tuner.kept_configs for conf in configs: self.assertTrue(conf.qos > self.baseline - 3.0)