Skip to content
Snippets Groups Projects
Commit fb889b56 authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Changed how tuner stores its configuration

parent f602d4b5
No related branches found
No related tags found
No related merge requests found
...@@ -12,13 +12,17 @@ msg_logger = logging.getLogger(__name__) ...@@ -12,13 +12,17 @@ msg_logger = logging.getLogger(__name__)
PathLike = Union[Path, str] PathLike = Union[Path, str]
def read_opentuner_db(filepath: PathLike) -> List[Tuple[Result, Configuration]]: def read_opentuner_db(filepath_or_uri: PathLike) -> List[Tuple[Result, Configuration]]:
filepath = Path(filepath) if "://" in filepath_or_uri:
uri = filepath_or_uri
else:
filepath = Path(filepath_or_uri)
uri = f"sqlite:///{filepath}"
try: try:
_, sess = resultsdb.connect(f"sqlite:///{filepath}") _, sess = resultsdb.connect(uri)
except: except Exception as e:
msg_logger.error("failed to load database: %s", filepath, exc_info=True) msg_logger.error("Failed to load database: %s", filepath_or_uri, exc_info=True)
return [] raise e
session: Session = sess() session: Session = sess()
latest_run_id = session.query(func.max(TuningRun.id)).all()[0][0] latest_run_id = session.query(func.max(TuningRun.id)).all()[0][0]
run_results = ( run_results = (
......
...@@ -5,8 +5,7 @@ from typing import Dict, List, NamedTuple, Optional, Tuple, Union ...@@ -5,8 +5,7 @@ from typing import Dict, List, NamedTuple, Optional, Tuple, Union
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from opentuner.measurement.interface import MeasurementInterface from opentuner.measurement.interface import MeasurementInterface
from opentuner.search.manipulator import (ConfigurationManipulator, from opentuner.search.manipulator import ConfigurationManipulator, EnumParameter
EnumParameter)
from ._logging import override_opentuner_config from ._logging import override_opentuner_config
...@@ -66,9 +65,14 @@ class Config(NamedTuple): ...@@ -66,9 +65,14 @@ class Config(NamedTuple):
class ApproxTuner: class ApproxTuner:
def __init__(self, app: ApproxApp) -> None: def __init__(self, app: ApproxApp) -> None:
self.app = app self.app = app
self.tune_sessions = [] self.all_configs = []
self.db = None self.kept_configs = []
self.keep_threshold = None self.keep_threshold = None
self._db = None
@property
def tuned(self) -> bool:
return not self._db is None
def tune( def tune(
self, self,
...@@ -77,17 +81,15 @@ class ApproxTuner: ...@@ -77,17 +81,15 @@ class ApproxTuner:
qos_keep_threshold: Optional[float] = None, qos_keep_threshold: Optional[float] = None,
accuracy_convention: str = "absolute" # TODO: this accuracy_convention: str = "absolute" # TODO: this
# TODO: more parameters + opentuner param forwarding # TODO: more parameters + opentuner param forwarding
): ) -> List[Config]:
"""Generate an optimal set of approximation configurations for the model.""" """Generate an optimal set of approximation configurations for the model."""
import socket
from opentuner.tuningrunmain import TuningRunMain from opentuner.tuningrunmain import TuningRunMain
from ._dbloader import read_opentuner_db
# By default, keep_threshold == tuner_threshold # By default, keep_threshold == tuner_threshold
opentuner_args = opentuner_default_args() opentuner_args = opentuner_default_args()
qos_keep_threshold = qos_keep_threshold or qos_tuner_threshold qos_keep_threshold = qos_keep_threshold or qos_tuner_threshold
self.keep_threshold = qos_keep_threshold
self.db = opentuner_args.database or f'opentuner.db/{socket.gethostname()}.db'
opentuner_args.test_limit = max_iter opentuner_args.test_limit = max_iter
tuner = TunerInterface( tuner = TunerInterface(
opentuner_args, self.app, qos_tuner_threshold, qos_keep_threshold, max_iter, opentuner_args, self.app, qos_tuner_threshold, qos_keep_threshold, max_iter,
...@@ -98,29 +100,36 @@ class ApproxTuner: ...@@ -98,29 +100,36 @@ class ApproxTuner:
# This is where opentuner runs # This is where opentuner runs
trm.main() trm.main()
def get_all_configs(self) -> List[Config]: # Parse and store results
from ._dbloader import read_opentuner_db self._db = opentuner_args.database
self.all_configs = [
if self.db is None or self.keep_threshold is None:
raise RuntimeError(
f"No tuning session has been run; call self.tune() first."
)
rets = read_opentuner_db(self.db)
return [
Config(result.accuracy, result.time, configuration.data) Config(result.accuracy, result.time, configuration.data)
for result, configuration in rets for result, configuration in read_opentuner_db(self._db)
if result.accuracy > self.keep_threshold ]
self.kept_configs = [
cfg for cfg in self.all_configs if cfg.qos > qos_keep_threshold
] ]
self.keep_threshold = qos_keep_threshold
return self.kept_configs
def write_configs_to_dir(self, directory: PathLike): def write_configs_to_dir(self, directory: PathLike):
from jsonpickle import encode from jsonpickle import encode
encode(self.get_all_configs(), directory) if not self.tuned:
raise RuntimeError(
f"No tuning session has been run; call self.tune() first."
)
encode(self.kept_configs, directory)
def plot_configs(self) -> plt.Figure: def plot_configs(self) -> plt.Figure:
if not self.tuned:
raise RuntimeError(
f"No tuning session has been run; call self.tune() first."
)
_, perf = self.app.measure_qos_perf({}, False) _, perf = self.app.measure_qos_perf({}, False)
fig, ax = plt.subplots() fig, ax = plt.subplots()
confs = self.get_all_configs() confs = self.kept_configs
if not confs: if not confs:
return fig return fig
qos_speedup = [(c.qos, perf / c.perf) for c in confs] qos_speedup = [(c.qos, perf / c.perf) for c in confs]
......
...@@ -52,6 +52,6 @@ class TestTorchApp(unittest.TestCase): ...@@ -52,6 +52,6 @@ class TestTorchApp(unittest.TestCase):
baseline, _ = app.measure_qos_perf({}, False) baseline, _ = app.measure_qos_perf({}, False)
tuner = app.get_tuner() tuner = app.get_tuner()
tuner.tune(100, baseline - 3.0) tuner.tune(100, baseline - 3.0)
configs = tuner.get_all_configs() configs = tuner.kept_configs
for conf in configs: for conf in configs:
self.assertTrue(conf.qos > baseline - 3.0) self.assertTrue(conf.qos > baseline - 3.0)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment