diff --git a/predtuner/_dbloader.py b/predtuner/_dbloader.py
index 860ff893e680ff82fa2f8fea9d23a10242dc7c69..834ca7d64ee48d65c2070587a1af163799621b3d 100644
--- a/predtuner/_dbloader.py
+++ b/predtuner/_dbloader.py
@@ -12,13 +12,17 @@ msg_logger = logging.getLogger(__name__)
 PathLike = Union[Path, str]
 
 
-def read_opentuner_db(filepath: PathLike) -> List[Tuple[Result, Configuration]]:
-    filepath = Path(filepath)
+def read_opentuner_db(filepath_or_uri: PathLike) -> List[Tuple[Result, Configuration]]:
+    if "://" in str(filepath_or_uri):
+        uri = filepath_or_uri
+    else:
+        filepath = Path(filepath_or_uri)
+        uri = f"sqlite:///{filepath}"
     try:
-        _, sess = resultsdb.connect(f"sqlite:///{filepath}")
-    except:
-        msg_logger.error("failed to load database: %s", filepath, exc_info=True)
-        return []
+        _, sess = resultsdb.connect(uri)
+    except Exception:
+        msg_logger.error("Failed to load database: %s", filepath_or_uri, exc_info=True)
+        raise
     session: Session = sess()
     latest_run_id = session.query(func.max(TuningRun.id)).all()[0][0]
     run_results = (
diff --git a/predtuner/approxapp.py b/predtuner/approxapp.py
index 2c512566e9f210634dbba367193a4499ea716d50..e06c81fd2898be38677a69df9099a3a2ce72b97a 100644
--- a/predtuner/approxapp.py
+++ b/predtuner/approxapp.py
@@ -5,8 +5,7 @@
 from typing import Dict, List, NamedTuple, Optional, Tuple, Union
 
 import matplotlib.pyplot as plt
 from opentuner.measurement.interface import MeasurementInterface
-from opentuner.search.manipulator import (ConfigurationManipulator,
-                                          EnumParameter)
+from opentuner.search.manipulator import ConfigurationManipulator, EnumParameter
 
 from ._logging import override_opentuner_config
@@ -66,9 +65,14 @@ class Config(NamedTuple):
 class ApproxTuner:
     def __init__(self, app: ApproxApp) -> None:
         self.app = app
-        self.tune_sessions = []
-        self.db = None
+        self.all_configs = []
+        self.kept_configs = []
         self.keep_threshold = None
+        self._db = None
+
+    @property
+    def tuned(self) -> bool:
+        return self._db is not None
 
     def tune(
         self,
@@ -77,17 +81,15 @@ class ApproxTuner:
         qos_keep_threshold: Optional[float] = None,
         accuracy_convention: str = "absolute"  # TODO: this
         # TODO: more parameters + opentuner param forwarding
-    ):
+    ) -> List[Config]:
         """Generate an optimal set of approximation configurations for the model."""
-        import socket
-
         from opentuner.tuningrunmain import TuningRunMain
 
+        from ._dbloader import read_opentuner_db
+
         # By default, keep_threshold == tuner_threshold
         opentuner_args = opentuner_default_args()
         qos_keep_threshold = qos_keep_threshold or qos_tuner_threshold
-        self.keep_threshold = qos_keep_threshold
-        self.db = opentuner_args.database or f'opentuner.db/{socket.gethostname()}.db'
         opentuner_args.test_limit = max_iter
         tuner = TunerInterface(
             opentuner_args, self.app, qos_tuner_threshold, qos_keep_threshold, max_iter,
@@ -98,29 +100,37 @@ class ApproxTuner:
         # This is where opentuner runs
         trm.main()
 
-    def get_all_configs(self) -> List[Config]:
-        from ._dbloader import read_opentuner_db
-
-        if self.db is None or self.keep_threshold is None:
-            raise RuntimeError(
-                f"No tuning session has been run; call self.tune() first."
-            )
-        rets = read_opentuner_db(self.db)
-        return [
+        # Parse and store results
+        self._db = opentuner_args.database
+        self.all_configs = [
             Config(result.accuracy, result.time, configuration.data)
-            for result, configuration in rets
-            if result.accuracy > self.keep_threshold
+            for result, configuration in read_opentuner_db(self._db)
+        ]
+        self.kept_configs = [
+            cfg for cfg in self.all_configs if cfg.qos > qos_keep_threshold
         ]
+        self.keep_threshold = qos_keep_threshold
+        return self.kept_configs
 
     def write_configs_to_dir(self, directory: PathLike):
         from jsonpickle import encode
 
-        encode(self.get_all_configs(), directory)
+        if not self.tuned:
+            raise RuntimeError(
+                "No tuning session has been run; call self.tune() first."
+            )
+        with open(directory, "w") as f:
+            f.write(encode(self.kept_configs))
 
     def plot_configs(self) -> plt.Figure:
+        if not self.tuned:
+            raise RuntimeError(
+                "No tuning session has been run; call self.tune() first."
+            )
+
         _, perf = self.app.measure_qos_perf({}, False)
         fig, ax = plt.subplots()
-        confs = self.get_all_configs()
+        confs = self.kept_configs
         if not confs:
             return fig
         qos_speedup = [(c.qos, perf / c.perf) for c in confs]
diff --git a/test/test_torchapp.py b/test/test_torchapp.py
index 7fc0eea7b3402b8bfa29640b51c19c333faa315d..c16af2fbfb92fd9ffcf5c036da8583c1ef0d021d 100644
--- a/test/test_torchapp.py
+++ b/test/test_torchapp.py
@@ -52,6 +52,6 @@ class TestTorchApp(unittest.TestCase):
         baseline, _ = app.measure_qos_perf({}, False)
         tuner = app.get_tuner()
         tuner.tune(100, baseline - 3.0)
-        configs = tuner.get_all_configs()
+        configs = tuner.kept_configs
         for conf in configs:
             self.assertTrue(conf.qos > baseline - 3.0)