From 69cef13cac9595f4108c0272015c2133012e2128 Mon Sep 17 00:00:00 2001
From: Yifan Zhao <yifanz16@illinois.edu>
Date: Mon, 15 Mar 2021 16:11:15 -0500
Subject: [PATCH] Improved plotting of configurations

---
 predtuner/approxapp.py  | 16 +++++++++++-----
 predtuner/modeledapp.py | 36 +++++++++++++++++++++++++++++++++++-
 2 files changed, 46 insertions(+), 6 deletions(-)

diff --git a/predtuner/approxapp.py b/predtuner/approxapp.py
index 25b2279..06717bd 100644
--- a/predtuner/approxapp.py
+++ b/predtuner/approxapp.py
@@ -114,8 +114,8 @@ class Config:
         self.test_qos: Optional[float] = test_qos
 
     @property
-    def qos_speedup(self):
-        return self.qos, 1 / self.cost
+    def speedup(self):
+        return 1 / self.cost
 
 
 T = TypeVar("T", bound=Config)
@@ -232,7 +232,7 @@ class ApproxTuner(Generic[T]):
 
     @staticmethod
     def take_best_configs(configs: List[T], n: Optional[int] = None) -> List[T]:
-        points = np.array([c.qos_speedup for c in configs])
+        points = np.array([(c.qos, c.speedup) for c in configs])
         taken_idx = is_pareto_efficient(points, take_n=n)
         return [configs[i] for i in taken_idx]
 
@@ -252,16 +252,22 @@ class ApproxTuner(Generic[T]):
             f.write(encode(confs, indent=2))
 
     def plot_configs(
-        self, show_qos_loss: bool = False, connect_best_points: bool = False
+        self,
+        show_qos_loss: bool = False,
+        connect_best_points: bool = False,
+        use_test_qos: bool = False,
     ) -> plt.Figure:
         if not self.tuned:
             raise RuntimeError(
                 f"No tuning session has been run; call self.tune() first."
             )
 
+        def qos_speedup(conf):
+            return conf.test_qos if use_test_qos else conf.qos, conf.speedup
+
         def get_points(confs):
             sorted_points = np.array(
-                sorted([c.qos_speedup for c in confs], key=lambda p: p[0])
+                sorted([qos_speedup(c) for c in confs], key=lambda p: p[0])
             ).T
             if show_qos_loss:
                 sorted_points[0] = self.baseline_qos - sorted_points[0]
diff --git a/predtuner/modeledapp.py b/predtuner/modeledapp.py
index 61fb980..a5d904a 100644
--- a/predtuner/modeledapp.py
+++ b/predtuner/modeledapp.py
@@ -5,6 +5,7 @@ import pickle
 from pathlib import Path
 from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type, Union
 
+import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 import torch
@@ -390,7 +391,7 @@ class ApproxModeledTuner(ApproxTuner):
             is_threshold_relative=is_threshold_relative,
             take_best_n=take_best_n,
             test_configs=False,  # Test configs below by ourselves
-            app_kwargs={"cost_model": cost_model, "qos_model": qos_model}
+            app_kwargs={"cost_model": cost_model, "qos_model": qos_model},
         )
         if validate_configs is None and qos_model != "none":
             msg_logger.info(
@@ -440,6 +441,39 @@ class ApproxModeledTuner(ApproxTuner):
         msg_logger.info("%d of %d configs remain", len(ret_configs), len(configs))
         return ret_configs
 
+    def plot_configs(
+        self, show_qos_loss: bool = False, connect_best_points: bool = False
+    ) -> plt.Figure:
+        if not self.tuned:
+            raise RuntimeError(
+                f"No tuning session has been run; call self.tune() first."
+            )
+
+        def get_points(confs, validated):
+            def qos_speedup(conf):
+                return conf.validated_qos if validated else conf.qos, conf.speedup
+
+            sorted_points = np.array(
+                sorted([qos_speedup(c) for c in confs], key=lambda p: p[0])
+            ).T
+            if show_qos_loss:
+                sorted_points[0] = self.baseline_qos - sorted_points[0]
+            return sorted_points
+
+        fig, ax = plt.subplots()
+        kept_confs = get_points(self.kept_configs, False)
+        best_confs = get_points(self.best_configs, False)
+        best_confs_val = get_points(self.best_configs, True)
+        ax.plot(kept_confs[0], kept_confs[1], "o", label="valid")
+        mode = "-o" if connect_best_points else "o"
+        ax.plot(best_confs[0], best_confs[1], mode, label="best")
+        mode = "-o" if connect_best_points else "o"
+        ax.plot(best_confs_val[0], best_confs_val[1], mode, label="best_validated")
+        ax.set_xlabel("QoS Loss" if show_qos_loss else "QoS")
+        ax.set_ylabel("Speedup (x)")
+        ax.legend()
+        return fig
+
     @classmethod
     def _get_config_class(cls) -> Type[Config]:
         return ValConfig
-- 
GitLab