From 7a84b294fed1dc3323bc3526669a17e39dbfb8eb Mon Sep 17 00:00:00 2001 From: Yifan Zhao <yifanz16@illinois.edu> Date: Sat, 23 Jan 2021 06:57:48 -0600 Subject: [PATCH] Passed test for tuning & result fetching --- predtuner/approxapp.py | 9 +++++++-- test/test_torchapp.py | 19 ++++++++----------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/predtuner/approxapp.py b/predtuner/approxapp.py index bb49ed0..192cfdb 100644 --- a/predtuner/approxapp.py +++ b/predtuner/approxapp.py @@ -66,6 +66,7 @@ class ApproxTuner: self.app = app self.tune_sessions = [] self.db = None + self.keep_threshold = None def tune( self, @@ -81,8 +82,9 @@ class ApproxTuner: from opentuner.tuningrunmain import TuningRunMain # By default, keep_threshold == tuner_threshold - qos_keep_threshold = qos_keep_threshold or qos_tuner_threshold opentuner_args = opentuner_default_args() + qos_keep_threshold = qos_keep_threshold or qos_tuner_threshold + self.keep_threshold = qos_keep_threshold self.db = opentuner_args.database or f'opentuner.db/{socket.gethostname()}.db' opentuner_args.test_limit = max_iter tuner = TunerInterface( @@ -94,7 +96,7 @@ class ApproxTuner: def get_all_configs(self) -> List[Config]: from ._dbloader import read_opentuner_db - if self.db is None: + if self.db is None or self.keep_threshold is None: raise RuntimeError( f"No tuning session has been run; call self.tune() first." ) @@ -102,6 +104,7 @@ class ApproxTuner: return [ Config(result.accuracy, result.time, configuration.data) for result, configuration in rets + if result.accuracy > self.keep_threshold ] def write_configs_to_dir(self, directory: PathLike): @@ -113,6 +116,8 @@ class ApproxTuner: _, perf = self.app.measure_qos_perf({}, False) fig, ax = plt.subplots() confs = self.get_all_configs() + if not confs: + return fig qos_speedup = [(c.qos, perf / c.perf) for c in confs] qoses, speedups = zip(*sorted(qos_speedup, key=lambda p: p[0])) ax.plot(qoses, speedups) diff --git a/test/test_torchapp.py b/test/test_torchapp.py index 1621a08..cb6fd5d 100644 --- a/test/test_torchapp.py +++ b/test/test_torchapp.py @@ -22,8 +22,8 @@ class TestTorchApp(unittest.TestCase): return TorchApp( "TestTorchApp", self.module, - DataLoader(self.dataset), - DataLoader(self.dataset), + DataLoader(self.dataset, batch_size=500), + DataLoader(self.dataset, batch_size=500), get_knobs_from_file(), accuracy, ) @@ -47,13 +47,10 @@ class TestTorchApp(unittest.TestCase): self.assertAlmostEqual(qos, 88.0) def test_tuning(self): - app = TorchApp( - "test", - self.module, - DataLoader(self.dataset, batch_size=4), - DataLoader(self.dataset, batch_size=4), - get_knobs_from_file(), - accuracy, - ) + app = self.get_app() + baseline, _ = app.measure_qos_perf({}, False) tuner = app.get_tuner() - tuner.tune(10, 3.0) + tuner.tune(100, baseline - 3.0) + configs = tuner.get_all_configs() + for conf in configs: + self.assertTrue(conf.qos > baseline - 3.0) -- GitLab