Skip to content
Snippets Groups Projects
Commit dfdd5318 authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Plot configs better

parent def1ad89
No related branches found
No related tags found
No related merge requests found
...@@ -133,6 +133,7 @@ class ApproxTuner(Generic[T]): ...@@ -133,6 +133,7 @@ class ApproxTuner(Generic[T]):
self.best_configs = [] self.best_configs = []
# The following will be filled after self.tune() is called # The following will be filled after self.tune() is called
self.keep_threshold = None self.keep_threshold = None
self.baseline_qos = None
@property @property
def tuned(self) -> bool: def tuned(self) -> bool:
...@@ -235,22 +236,31 @@ class ApproxTuner(Generic[T]): ...@@ -235,22 +236,31 @@ class ApproxTuner(Generic[T]):
with filepath.open("w") as f: with filepath.open("w") as f:
f.write(encode(self.best_configs, indent=2)) f.write(encode(self.best_configs, indent=2))
def plot_configs(self) -> plt.Figure: def plot_configs(
self, show_qos_loss: bool = False, connect_best_points: bool = False
) -> plt.Figure:
if not self.tuned: if not self.tuned:
raise RuntimeError( raise RuntimeError(
f"No tuning session has been run; call self.tune() first." f"No tuning session has been run; call self.tune() first."
) )
def get_points(confs):
sorted_points = np.array(
sorted([c.qos_speedup for c in confs], key=lambda p: p[0])
).T
if show_qos_loss:
sorted_points[0] = self.baseline_qos - sorted_points[0]
return sorted_points
fig, ax = plt.subplots() fig, ax = plt.subplots()
confs = self.kept_configs kept_confs = get_points(self.kept_configs)
if not confs: best_confs = get_points(self.best_configs)
return fig ax.plot(kept_confs[0], kept_confs[1], "o", label="valid")
qos_speedup = [c.qos_speedup for c in confs] mode = "-o" if connect_best_points else "o"
qoses, speedups = zip(*sorted(qos_speedup, key=lambda p: p[0])) ax.plot(best_confs[0], best_confs[1], mode, label="best")
ax.plot(qoses, speedups) ax.set_xlabel("QoS Loss" if show_qos_loss else "QoS")
ax.scatter(qoses, speedups) ax.set_ylabel("Speedup (x)")
ax.set_xlabel("qos") ax.legend()
ax.set_ylabel("speedup")
return fig return fig
def _get_tuner_interface( def _get_tuner_interface(
...@@ -265,9 +275,9 @@ class ApproxTuner(Generic[T]): ...@@ -265,9 +275,9 @@ class ApproxTuner(Generic[T]):
# By default, keep_threshold == tuner_threshold # By default, keep_threshold == tuner_threshold
self.keep_threshold = qos_keep_threshold or qos_tuner_threshold self.keep_threshold = qos_keep_threshold or qos_tuner_threshold
if is_threshold_relative: if is_threshold_relative:
baseline_qos, _ = self.app.measure_qos_cost({}, False) self.baseline_qos, _ = self.app.measure_qos_cost({}, False)
qos_tuner_threshold = baseline_qos - qos_tuner_threshold qos_tuner_threshold = self.baseline_qos - qos_tuner_threshold
self.keep_threshold = baseline_qos - self.keep_threshold self.keep_threshold = self.baseline_qos - self.keep_threshold
opentuner_args.test_limit = max_iter opentuner_args.test_limit = max_iter
msg_logger.info( msg_logger.info(
"Tuner QoS threshold: %f; keeping configurations with QoS >= %f", "Tuner QoS threshold: %f; keeping configurations with QoS >= %f",
......
...@@ -34,7 +34,7 @@ app = TorchApp( ...@@ -34,7 +34,7 @@ app = TorchApp(
) )
baseline, _ = app.measure_qos_cost({}, False) baseline, _ = app.measure_qos_cost({}, False)
tuner = app.get_tuner() tuner = app.get_tuner()
tuner.tune(100, 2.1, 3.0, True, 50, cost_model="cost_linear", qos_model="qos_p1") tuner.tune(500, 2.1, 3.0, True, 20, cost_model="cost_linear", qos_model="qos_p1")
tuner.dump_configs("tuner_results/test/configs.json") tuner.dump_configs("tuner_results/test/configs.json")
fig = tuner.plot_configs() fig = tuner.plot_configs(show_qos_loss=True)
fig.savefig("tuner_results/test/configs.png", dpi=300) fig.savefig("tuner_results/test/configs.png", dpi=300)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment