Skip to content
Snippets Groups Projects
Commit dfdd5318 authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Plot configs better

parent def1ad89
No related branches found
No related tags found
No related merge requests found
......@@ -133,6 +133,7 @@ class ApproxTuner(Generic[T]):
self.best_configs = []
# The following will be filled after self.tune() is called
self.keep_threshold = None
self.baseline_qos = None
@property
def tuned(self) -> bool:
......@@ -235,22 +236,31 @@ class ApproxTuner(Generic[T]):
with filepath.open("w") as f:
f.write(encode(self.best_configs, indent=2))
def plot_configs(self) -> plt.Figure:
def plot_configs(
self, show_qos_loss: bool = False, connect_best_points: bool = False
) -> plt.Figure:
if not self.tuned:
raise RuntimeError(
f"No tuning session has been run; call self.tune() first."
)
def get_points(confs):
sorted_points = np.array(
sorted([c.qos_speedup for c in confs], key=lambda p: p[0])
).T
if show_qos_loss:
sorted_points[0] = self.baseline_qos - sorted_points[0]
return sorted_points
fig, ax = plt.subplots()
confs = self.kept_configs
if not confs:
return fig
qos_speedup = [c.qos_speedup for c in confs]
qoses, speedups = zip(*sorted(qos_speedup, key=lambda p: p[0]))
ax.plot(qoses, speedups)
ax.scatter(qoses, speedups)
ax.set_xlabel("qos")
ax.set_ylabel("speedup")
kept_confs = get_points(self.kept_configs)
best_confs = get_points(self.best_configs)
ax.plot(kept_confs[0], kept_confs[1], "o", label="valid")
mode = "-o" if connect_best_points else "o"
ax.plot(best_confs[0], best_confs[1], mode, label="best")
ax.set_xlabel("QoS Loss" if show_qos_loss else "QoS")
ax.set_ylabel("Speedup (x)")
ax.legend()
return fig
def _get_tuner_interface(
......@@ -265,9 +275,9 @@ class ApproxTuner(Generic[T]):
# By default, keep_threshold == tuner_threshold
self.keep_threshold = qos_keep_threshold or qos_tuner_threshold
if is_threshold_relative:
baseline_qos, _ = self.app.measure_qos_cost({}, False)
qos_tuner_threshold = baseline_qos - qos_tuner_threshold
self.keep_threshold = baseline_qos - self.keep_threshold
self.baseline_qos, _ = self.app.measure_qos_cost({}, False)
qos_tuner_threshold = self.baseline_qos - qos_tuner_threshold
self.keep_threshold = self.baseline_qos - self.keep_threshold
opentuner_args.test_limit = max_iter
msg_logger.info(
"Tuner QoS threshold: %f; keeping configurations with QoS >= %f",
......
......@@ -34,7 +34,7 @@ app = TorchApp(
)
baseline, _ = app.measure_qos_cost({}, False)
tuner = app.get_tuner()
tuner.tune(100, 2.1, 3.0, True, 50, cost_model="cost_linear", qos_model="qos_p1")
tuner.tune(500, 2.1, 3.0, True, 20, cost_model="cost_linear", qos_model="qos_p1")
tuner.dump_configs("tuner_results/test/configs.json")
fig = tuner.plot_configs()
fig = tuner.plot_configs(show_qos_loss=True)
fig.savefig("tuner_results/test/configs.png", dpi=300)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment