From 7a84b294fed1dc3323bc3526669a17e39dbfb8eb Mon Sep 17 00:00:00 2001
From: Yifan Zhao <yifanz16@illinois.edu>
Date: Sat, 23 Jan 2021 06:57:48 -0600
Subject: [PATCH] Make tuning & result-fetching test pass

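ApproxTuner now remembers the QoS keep threshold passed to (or defaulted in)
tune(), get_all_configs() returns only configurations whose QoS clears that
threshold, and the tuning test measures the baseline QoS first and tunes
against it. A rough sketch of the flow the updated test exercises, assuming
an `app` built as in TestTorchApp.get_app() below (the output directory name
is only a placeholder, not part of this change):

    # Baseline QoS/perf of the unmodified app (empty configuration).
    baseline, _ = app.measure_qos_perf({}, False)
    tuner = app.get_tuner()
    # Iteration budget and QoS tuner threshold, as in test_tuning.
    tuner.tune(100, baseline - 3.0)
    # Only configurations above the keep threshold come back now.
    configs = tuner.get_all_configs()
    assert all(conf.qos > baseline - 3.0 for conf in configs)
    tuner.write_configs_to_dir("tuned_configs")
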
---
 predtuner/approxapp.py |  9 +++++++--
 test/test_torchapp.py  | 19 ++++++++-----------
 2 files changed, 15 insertions(+), 13 deletions(-)

diff --git a/predtuner/approxapp.py b/predtuner/approxapp.py
index bb49ed0..192cfdb 100644
--- a/predtuner/approxapp.py
+++ b/predtuner/approxapp.py
@@ -66,6 +66,7 @@ class ApproxTuner:
         self.app = app
         self.tune_sessions = []
         self.db = None
+        self.keep_threshold = None
 
     def tune(
         self,
@@ -81,8 +82,9 @@ class ApproxTuner:
         from opentuner.tuningrunmain import TuningRunMain
 
         # By default, keep_threshold == tuner_threshold
-        qos_keep_threshold = qos_keep_threshold or qos_tuner_threshold
         opentuner_args = opentuner_default_args()
+        qos_keep_threshold = qos_keep_threshold or qos_tuner_threshold
+        self.keep_threshold = qos_keep_threshold
         self.db = opentuner_args.database or f'opentuner.db/{socket.gethostname()}.db'
         opentuner_args.test_limit = max_iter
         tuner = TunerInterface(
@@ -94,7 +96,7 @@ class ApproxTuner:
     def get_all_configs(self) -> List[Config]:
         from ._dbloader import read_opentuner_db
 
-        if self.db is None:
+        if self.db is None or self.keep_threshold is None:
             raise RuntimeError(
                 f"No tuning session has been run; call self.tune() first."
             )
@@ -102,6 +104,7 @@ class ApproxTuner:
         return [
             Config(result.accuracy, result.time, configuration.data)
             for result, configuration in rets
+            if result.accuracy > self.keep_threshold
         ]
 
     def write_configs_to_dir(self, directory: PathLike):
@@ -113,6 +116,8 @@ class ApproxTuner:
         _, perf = self.app.measure_qos_perf({}, False)
         fig, ax = plt.subplots()
         confs = self.get_all_configs()
+        if not confs:
+            return fig
         qos_speedup = [(c.qos, perf / c.perf) for c in confs]
         qoses, speedups = zip(*sorted(qos_speedup, key=lambda p: p[0]))
         ax.plot(qoses, speedups)
diff --git a/test/test_torchapp.py b/test/test_torchapp.py
index 1621a08..cb6fd5d 100644
--- a/test/test_torchapp.py
+++ b/test/test_torchapp.py
@@ -22,8 +22,8 @@ class TestTorchApp(unittest.TestCase):
         return TorchApp(
             "TestTorchApp",
             self.module,
-            DataLoader(self.dataset),
-            DataLoader(self.dataset),
+            DataLoader(self.dataset, batch_size=500),
+            DataLoader(self.dataset, batch_size=500),
             get_knobs_from_file(),
             accuracy,
         )
@@ -47,13 +47,10 @@ class TestTorchApp(unittest.TestCase):
         self.assertAlmostEqual(qos, 88.0)
 
     def test_tuning(self):
-        app = TorchApp(
-            "test",
-            self.module,
-            DataLoader(self.dataset, batch_size=4),
-            DataLoader(self.dataset, batch_size=4),
-            get_knobs_from_file(),
-            accuracy,
-        )
+        app = self.get_app()
+        baseline, _ = app.measure_qos_perf({}, False)
         tuner = app.get_tuner()
-        tuner.tune(10, 3.0)
+        tuner.tune(100, baseline - 3.0)
+        configs = tuner.get_all_configs()
+        for conf in configs:
+            self.assertTrue(conf.qos > baseline - 3.0)
-- 
GitLab