Skip to content
Snippets Groups Projects
Commit f935e7ab authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Load the results written in opentuner database

parent 9930b3c7
No related branches found
No related tags found
No related merge requests found
import logging
from pathlib import Path
from typing import List, Tuple, Union
from opentuner import resultsdb
from opentuner.resultsdb.models import Configuration, Result, TuningRun
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.expression import func
msg_logger = logging.getLogger(__name__)
PathLike = Union[Path, str]
def read_opentuner_db(filepath: PathLike) -> List[Tuple[Result, Configuration]]:
filepath = Path(filepath)
try:
_, sess = resultsdb.connect(f"sqlite:///{filepath}")
except:
msg_logger.error("failed to load database: %s", filepath, exc_info=True)
return []
session: Session = sess()
latest_run_id = session.query(func.max(TuningRun.id)).all()[0][0]
run_results = (
session.query(Result, Configuration)
.filter_by(tuning_run_id=latest_run_id)
.filter(Result.configuration_id == Configuration.id)
.all()
)
return run_results
import abc
import logging
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, NamedTuple, Optional, Tuple, Union
import matplotlib.pyplot as plt
from opentuner.measurement.interface import MeasurementInterface
from opentuner.search.manipulator import ConfigurationManipulator, EnumParameter
from opentuner.search.manipulator import (ConfigurationManipulator,
EnumParameter)
msg_logger = logging.getLogger(__name__)
KnobsT = Dict[str, str]
......@@ -53,14 +55,17 @@ class ApproxApp(abc.ABC):
return ""
class Config:
pass # TODO: work this out later
class Config(NamedTuple):
qos: float
perf: float
knobs: KnobsT
class ApproxTuner:
def __init__(self, app: ApproxApp) -> None:
self.app = app
self.tune_sessions = []
self.db = None
def tune(
self,
......@@ -71,11 +76,14 @@ class ApproxTuner:
# TODO: more parameters + opentuner param forwarding
):
"""Generate an optimal set of approximation configurations for the model."""
import socket
from opentuner.tuningrunmain import TuningRunMain
# By default, keep_threshold == tuner_threshold
qos_keep_threshold = qos_keep_threshold or qos_tuner_threshold
opentuner_args = opentuner_default_args()
self.db = opentuner_args.database or f'opentuner.db/{socket.gethostname()}.db'
opentuner_args.test_limit = max_iter
tuner = TunerInterface(
opentuner_args, self.app, qos_tuner_threshold, qos_keep_threshold, max_iter,
......@@ -83,20 +91,35 @@ class ApproxTuner:
# This is where opentuner runs
TuningRunMain(tuner, opentuner_args).main()
# More helpers for selecting a config omitted for brevity
def get_all_configs(self) -> List[Config]:
return [] # TODO: parse opentuner database (do they have helpers?)
# TODO
# Work out details of saving / loading
# Important to keep association between model, weights, and configs
# Especially when retraining is involved
def store_configs(self, path: PathLike):
pass
def load_configs(self, path: PathLike):
pass
from ._dbloader import read_opentuner_db
if self.db is None:
raise RuntimeError(
f"No tuning session has been run; call self.tune() first."
)
rets = read_opentuner_db(self.db)
return [
Config(result.accuracy, result.time, configuration.data)
for result, configuration in rets
]
def write_configs_to_dir(self, directory: PathLike):
from jsonpickle import encode
encode(self.get_all_configs(), directory)
def plot_configs(self) -> plt.Figure:
_, perf = self.app.measure_qos_perf({}, False)
fig, ax = plt.subplots()
confs = self.get_all_configs()
qos_speedup = [(c.qos, perf / c.perf) for c in confs]
qoses, speedups = zip(*sorted(qos_speedup, key=lambda p: p[0]))
ax.plot(qoses, speedups)
ax.scatter(qoses, speedups)
ax.set_xlabel("qos")
ax.set_ylabel("speedup")
return fig
def opentuner_default_args():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment