From 598ffbf44d560cb6960277a0495a8c8bd3fd02eb Mon Sep 17 00:00:00 2001
From: Yifan Zhao <yifanz16@illinois.edu>
Date: Mon, 15 Mar 2021 15:57:52 -0500
Subject: [PATCH] Fix problems with config validation and calibration

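Replace implicit **kwargs forwarding to the application with an explicit
app_kwargs dict, and rework config testing/validation: configs are now
deep-copied, measured on the validation/test set, dropped if their measured
QoS falls below the keep threshold, and the mean absolute QoS difference
against the tuning-time QoS is reported.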
---
 predtuner/approxapp.py  | 33 ++++++++++++++++---------
 predtuner/modeledapp.py | 54 ++++++++++++++++++++++++++++++-----------
 2 files changed, 62 insertions(+), 25 deletions(-)
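
Reviewer note: a minimal sketch of the new calling convention on the base
ApproxTuner, for illustration only; get_tuner(), max_iter, and the model
names below are assumptions, not taken from this patch.

    tuner = app.get_tuner()  # hypothetical way to obtain an ApproxTuner
    best_configs = tuner.tune(
        max_iter=100,                 # assumed tuning budget
        qos_tuner_threshold=2.1,      # placeholder threshold
        is_threshold_relative=True,
        # Keyword arguments for the app are now an explicit dict, not **kwargs:
        app_kwargs={"cost_model": "cost_linear", "qos_model": "qos_p1"},
    )
    # tune() returns only configs whose measured QoS stays above the keep
    # threshold (see test_configs / _update_configs below).

ApproxModeledTuner.tune() still takes cost_model/qos_model directly and
forwards them through app_kwargs itself, so its callers are unaffected.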

diff --git a/predtuner/approxapp.py b/predtuner/approxapp.py
index b34b1a1..25b2279 100644
--- a/predtuner/approxapp.py
+++ b/predtuner/approxapp.py
@@ -147,7 +147,7 @@ class ApproxTuner(Generic[T]):
         is_threshold_relative: bool = False,
         take_best_n: Optional[int] = None,
         test_configs: bool = True,
-        **kwargs
+        app_kwargs: Optional[dict] = None
         # TODO: more parameters + opentuner param forwarding
     ) -> List[T]:
         from opentuner.tuningrunmain import TuningRunMain
@@ -169,7 +169,7 @@ class ApproxTuner(Generic[T]):
             qos_tuner_threshold,
             qos_keep_threshold,
             is_threshold_relative,
-            self._get_app_kwargs(**kwargs),
+            app_kwargs or {},
         )
         assert self.keep_threshold is not None
         trm = TuningRunMain(tuner, opentuner_args)
@@ -202,19 +202,33 @@ class ApproxTuner(Generic[T]):
             len(self.best_configs),
         )
         if test_configs:
-            msg_logger.info("Checking configurations on test inputs")
-            self.test_configs_(self.best_configs)
+            msg_logger.info("Calibrating configurations on test inputs")
+            self.best_configs = self.test_configs(self.best_configs)
         return self.best_configs
 
-    def test_configs_(self, configs: List[T]):
+    def test_configs(self, configs: List[Config]) -> List[Config]:
+        from copy import deepcopy
+
         from tqdm import tqdm
 
+        assert self.keep_threshold is not None
+        if not configs:
+            return []
+        ret_configs = []
+        total_error = 0
         for cfg in tqdm(configs, leave=False):
-            cfg: T
-            if cfg.test_qos is not None:
-                continue
+            cfg = deepcopy(cfg)
+            assert cfg.test_qos is None
             cfg.test_qos, _ = self.app.measure_qos_cost(cfg.knobs, True)
             msg_logger.debug(f"Calibration: {cfg.qos} (mean) -> {cfg.test_qos} (mean)")
+            total_error += abs(cfg.qos - cfg.test_qos)
+            if cfg.test_qos > self.keep_threshold:
+                ret_configs.append(cfg)
+            else:
+                msg_logger.debug("Config removed")
+        mean_err = total_error / len(configs)
+        msg_logger.info("Mean absolute QoS difference of calibration: %f", mean_err)
+        return ret_configs
 
     @staticmethod
     def take_best_configs(configs: List[T], n: Optional[int] = None) -> List[T]:
@@ -294,9 +308,6 @@ class ApproxTuner(Generic[T]):
             **app_kwargs,
         )
 
-    def _get_app_kwargs(self, **kwargs):
-        return {}
-
     @classmethod
     def _get_config_class(cls) -> Type[Config]:
         return Config
diff --git a/predtuner/modeledapp.py b/predtuner/modeledapp.py
index adeb533..61fb980 100644
--- a/predtuner/modeledapp.py
+++ b/predtuner/modeledapp.py
@@ -192,6 +192,8 @@ class QoSModelP1(IQoSModel):
         qos_metric: Callable[[torch.Tensor], float],
         storage: PathLike = None,
     ) -> None:
+        from torch.nn.functional import softmax
+
         super().__init__()
         self.app = app
         self.output_f = tensor_output_getter
@@ -387,32 +389,56 @@ class ApproxModeledTuner(ApproxTuner):
             qos_keep_threshold=qos_keep_threshold,
             is_threshold_relative=is_threshold_relative,
             take_best_n=take_best_n,
-            test_configs=test_configs,
-            cost_model=cost_model,
-            qos_model=qos_model,
+            test_configs=False,  # we test configs ourselves below
+            app_kwargs={"cost_model": cost_model, "qos_model": qos_model},
         )
         if validate_configs is None and qos_model != "none":
             msg_logger.info(
                 'Validating configurations due to using qos model "%s"', qos_model
             )
-            self.validate_configs_(self.best_configs)
+            self.best_configs = self._update_configs(self.best_configs, False)
         elif validate_configs:
             msg_logger.info("Validating configurations as user requested")
-            self.validate_configs_(self.best_configs)
+            self.best_configs = self._update_configs(self.best_configs, False)
+        if test_configs:
+            msg_logger.info("Calibrating configurations on test inputs")
+            self.best_configs = self._update_configs(self.best_configs, True)
         return ret
 
-    def validate_configs_(self, configs: List[ValConfig]):
+    def _update_configs(self, configs: List[ValConfig], test_mode: bool) -> List[ValConfig]:
+        from copy import deepcopy
+
         from tqdm import tqdm
 
+        assert self.keep_threshold is not None
+        if not configs:
+            msg_logger.info("No configurations found.")
+            return []
+        ret_configs = []
+        total_error = 0
         for cfg in tqdm(configs, leave=False):
-            cfg: ValConfig
-            if cfg.validated_qos is not None:
-                continue
-            cfg.validated_qos, _ = self.app.measure_qos_cost(cfg.knobs, False)
-            msg_logger.debug(f"Validation: {cfg.qos} (mean) -> {cfg.test_qos} (mean)")
-
-    def _get_app_kwargs(self, cost_model: str, qos_model: str):
-        return {"cost_model": cost_model, "qos_model": qos_model}
+            cfg = deepcopy(cfg)
+            qos, _ = self.app.measure_qos_cost(cfg.knobs, test_mode)
+            if test_mode:
+                assert cfg.test_qos is None
+                cfg.test_qos = qos
+                msg_logger.debug(f"Calibration: {cfg.qos} (mean) -> {qos} (mean)")
+            else:
+                assert cfg.validated_qos is None
+                cfg.validated_qos = qos
+                msg_logger.debug(f"Validation: {cfg.qos} (mean) -> {qos} (mean)")
+            total_error += abs(cfg.qos - qos)
+            if qos > self.keep_threshold:
+                ret_configs.append(cfg)
+            else:
+                msg_logger.debug("Config removed")
+        mean_err = total_error / len(configs)
+        if test_mode:
+            msg_logger.info("Mean absolute QoS difference of calibration: %f", mean_err)
+        else:
+            msg_logger.info("Mean absolute QoS difference of validation: %f", mean_err)
+        msg_logger.info("%d of %d configs remain", len(ret_configs), len(configs))
+        return ret_configs
 
     @classmethod
     def _get_config_class(cls) -> Type[Config]:
-- 
GitLab