Skip to content
Snippets Groups Projects
Commit fb889b56 authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Changed how tuner stores its configuration

parent f602d4b5
No related branches found
No related tags found
No related merge requests found
......@@ -12,13 +12,17 @@ msg_logger = logging.getLogger(__name__)
PathLike = Union[Path, str]
def read_opentuner_db(filepath: PathLike) -> List[Tuple[Result, Configuration]]:
filepath = Path(filepath)
def read_opentuner_db(filepath_or_uri: PathLike) -> List[Tuple[Result, Configuration]]:
if "://" in filepath_or_uri:
uri = filepath_or_uri
else:
filepath = Path(filepath_or_uri)
uri = f"sqlite:///{filepath}"
try:
_, sess = resultsdb.connect(f"sqlite:///{filepath}")
except:
msg_logger.error("failed to load database: %s", filepath, exc_info=True)
return []
_, sess = resultsdb.connect(uri)
except Exception as e:
msg_logger.error("Failed to load database: %s", filepath_or_uri, exc_info=True)
raise e
session: Session = sess()
latest_run_id = session.query(func.max(TuningRun.id)).all()[0][0]
run_results = (
......
......@@ -5,8 +5,7 @@ from typing import Dict, List, NamedTuple, Optional, Tuple, Union
import matplotlib.pyplot as plt
from opentuner.measurement.interface import MeasurementInterface
from opentuner.search.manipulator import (ConfigurationManipulator,
EnumParameter)
from opentuner.search.manipulator import ConfigurationManipulator, EnumParameter
from ._logging import override_opentuner_config
......@@ -66,9 +65,14 @@ class Config(NamedTuple):
class ApproxTuner:
def __init__(self, app: ApproxApp) -> None:
self.app = app
self.tune_sessions = []
self.db = None
self.all_configs = []
self.kept_configs = []
self.keep_threshold = None
self._db = None
@property
def tuned(self) -> bool:
return not self._db is None
def tune(
self,
......@@ -77,17 +81,15 @@ class ApproxTuner:
qos_keep_threshold: Optional[float] = None,
accuracy_convention: str = "absolute" # TODO: this
# TODO: more parameters + opentuner param forwarding
):
) -> List[Config]:
"""Generate an optimal set of approximation configurations for the model."""
import socket
from opentuner.tuningrunmain import TuningRunMain
from ._dbloader import read_opentuner_db
# By default, keep_threshold == tuner_threshold
opentuner_args = opentuner_default_args()
qos_keep_threshold = qos_keep_threshold or qos_tuner_threshold
self.keep_threshold = qos_keep_threshold
self.db = opentuner_args.database or f'opentuner.db/{socket.gethostname()}.db'
opentuner_args.test_limit = max_iter
tuner = TunerInterface(
opentuner_args, self.app, qos_tuner_threshold, qos_keep_threshold, max_iter,
......@@ -98,29 +100,36 @@ class ApproxTuner:
# This is where opentuner runs
trm.main()
def get_all_configs(self) -> List[Config]:
from ._dbloader import read_opentuner_db
if self.db is None or self.keep_threshold is None:
raise RuntimeError(
f"No tuning session has been run; call self.tune() first."
)
rets = read_opentuner_db(self.db)
return [
# Parse and store results
self._db = opentuner_args.database
self.all_configs = [
Config(result.accuracy, result.time, configuration.data)
for result, configuration in rets
if result.accuracy > self.keep_threshold
for result, configuration in read_opentuner_db(self._db)
]
self.kept_configs = [
cfg for cfg in self.all_configs if cfg.qos > qos_keep_threshold
]
self.keep_threshold = qos_keep_threshold
return self.kept_configs
def write_configs_to_dir(self, directory: PathLike):
from jsonpickle import encode
encode(self.get_all_configs(), directory)
if not self.tuned:
raise RuntimeError(
f"No tuning session has been run; call self.tune() first."
)
encode(self.kept_configs, directory)
def plot_configs(self) -> plt.Figure:
if not self.tuned:
raise RuntimeError(
f"No tuning session has been run; call self.tune() first."
)
_, perf = self.app.measure_qos_perf({}, False)
fig, ax = plt.subplots()
confs = self.get_all_configs()
confs = self.kept_configs
if not confs:
return fig
qos_speedup = [(c.qos, perf / c.perf) for c in confs]
......
......@@ -52,6 +52,6 @@ class TestTorchApp(unittest.TestCase):
baseline, _ = app.measure_qos_perf({}, False)
tuner = app.get_tuner()
tuner.tune(100, baseline - 3.0)
configs = tuner.get_all_configs()
configs = tuner.kept_configs
for conf in configs:
self.assertTrue(conf.qos > baseline - 3.0)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment