Commit fe456002 authored by Yifan Zhao

Put up framework for modeled tuning

parent f6e3cf2f
 import abc
 import logging
 from pathlib import Path
-from typing import Dict, Generic, List, NamedTuple, Optional, Tuple, TypeVar, Union
+from typing import Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union

 import matplotlib.pyplot as plt
 import numpy as np
 from opentuner.measurement.interface import MeasurementInterface
-from opentuner.resultsdb.models import Configuration, Result
-from opentuner.search.manipulator import ConfigurationManipulator, EnumParameter
+from opentuner.search.manipulator import (ConfigurationManipulator,
+                                          EnumParameter)

 from ._logging import override_opentuner_config
 from ._pareto import is_pareto_efficient
@@ -61,32 +61,33 @@ class ApproxApp(abc.ABC):
 class Config:
     def __init__(
-        self, qos: float, calib_qos: Optional[float], perf: float, knobs: KnobsT
+        self, qos: float, perf: float, knobs: KnobsT, calib_qos: Optional[float] = None
     ) -> None:
         self.qos = qos
-        self.calib_qos = calib_qos
         self.perf = perf
         self.knobs = knobs
+        self.calib_qos: Optional[float] = calib_qos
T = TypeVar("T", bound=Config) T = TypeVar("T", bound=Config)
# ApproxTuner is generic over the type of the config # IOpenTuner is generic over the type of the config
# So that the user can use custom Config inherited from Config # So that the user can use custom Config inherited from Config
# (in which case they need to override `get_all_configs_from_db`). # (in which case they need to override `get_all_configs_from_db`).
 class ApproxTuner(Generic[T]):
     def __init__(self, app: ApproxApp) -> None:
         self.app = app
+        self._tuned = False
         self.all_configs = []
         self.kept_configs = []
         self.best_configs = []
+        # The following will be filled after self.tune() is called
         self.keep_threshold = None
-        self._db = None

     @property
     def tuned(self) -> bool:
-        return not self._db is None
+        return self._tuned

     def tune(
         self,
@@ -95,54 +96,44 @@ class ApproxTuner(Generic[T]):
         qos_keep_threshold: Optional[float] = None,
         is_threshold_relative: bool = False,
         take_best_n: Optional[int] = None,
-        calibrate: bool = True
+        calibrate: bool = True,
+        **kwargs
         # TODO: more parameters + opentuner param forwarding
-    ) -> List[Config]:
+    ) -> List[T]:
+        """Generate an optimal set of approximation configurations for the model."""
         from opentuner.tuningrunmain import TuningRunMain

         from ._dbloader import read_opentuner_db

-        # By default, keep_threshold == tuner_threshold
         opentuner_args = opentuner_default_args()
-        qos_keep_threshold = qos_keep_threshold or qos_tuner_threshold
-        if is_threshold_relative:
-            baseline_qos, _ = self.app.measure_qos_perf({}, False)
-            qos_tuner_threshold = baseline_qos - qos_tuner_threshold
-            qos_keep_threshold = baseline_qos - qos_keep_threshold
-        opentuner_args.test_limit = max_iter
-        tuner = TunerInterface(
-            opentuner_args, self.app, qos_tuner_threshold, qos_keep_threshold, max_iter,
-        )
+        tuner = self._get_tuner_interface(
+            opentuner_args,
+            max_iter,
+            qos_tuner_threshold,
+            qos_keep_threshold,
+            is_threshold_relative,
+            self._get_app_kwargs(**kwargs),
+        )
+        assert self.keep_threshold is not None
         trm = TuningRunMain(tuner, opentuner_args)
         # TuningRunMain.__init__ initializes its own logger, so we'll override it and use ours
         override_opentuner_config()
         # This is where opentuner runs
         trm.main()
         # Parse and store results
-        self._db = opentuner_args.database
-        self.keep_threshold = qos_keep_threshold
-        self.all_configs = list(
-            self.get_all_configs_from_db(read_opentuner_db(self._db))
-        )
+        self._tuned = True
+        config_ty = self._get_config_class()
+        self.all_configs = [
+            config_ty(result.accuracy, result.time, configuration.data)
+            for result, configuration in read_opentuner_db(opentuner_args.database)
+        ]
         self.kept_configs = [
-            cfg for cfg in self.all_configs if cfg.qos > qos_keep_threshold
+            cfg for cfg in self.all_configs if cfg.qos > self.keep_threshold
         ]
         self.best_configs = self.take_best_configs(self.kept_configs, take_best_n)
         if calibrate:
             self.calibrate_configs_(self.best_configs)
         return self.best_configs

-    @classmethod
-    def get_all_configs_from_db(
-        cls, results_configs: List[Tuple[Result, Configuration]]
-    ) -> Tuple[T]:
-        return tuple(
-            Config(result.accuracy, None, result.time, configuration.data)
-            for result, configuration in results_configs
-        )
-
     def calibrate_configs_(self, configs: List[T]):
         from tqdm import tqdm
@@ -154,9 +145,7 @@ class ApproxTuner(Generic[T]):
             msg_logger.debug(f"Calibration: {cfg.qos} (mean) -> {cfg.calib_qos} (mean)")

     @staticmethod
-    def take_best_configs(
-        configs: List[Config], n: Optional[int] = None
-    ) -> List[Config]:
+    def take_best_configs(configs: List[T], n: Optional[int] = None) -> List[T]:
         points = np.array([[c.perf, c.qos] for c in configs])
         taken_idx = is_pareto_efficient(points, take_n=n)
         return [configs[i] for i in taken_idx]
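take_best_configs selects the Pareto front of the kept configurations: a config survives only if no other config beats it on both axes at once, and take_n optionally caps how many points are returned. is_pareto_efficient lives in ._pareto and is not part of this commit; a minimal stand-in, assuming both coordinates are oriented so that larger is better, could look like this:

import numpy as np

def pareto_indices(points: np.ndarray) -> np.ndarray:
    # points[i] = [perf, qos]; assumes larger is better on both axes
    # (negate a cost such as running time before calling).
    keep = np.ones(len(points), dtype=bool)
    for i in range(len(points)):
        others = np.delete(points, i, axis=0)
        # i is dominated if some other point is >= everywhere and > somewhere
        keep[i] = not np.any(
            np.all(others >= points[i], axis=1)
            & np.any(others > points[i], axis=1)
        )
    return np.flatnonzero(keep)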
@@ -193,6 +182,38 @@ class ApproxTuner(Generic[T]):
         ax.set_ylabel("speedup")
         return fig

+    def _get_tuner_interface(
+        self,
+        opentuner_args,
+        max_iter: int,
+        qos_tuner_threshold: float,
+        qos_keep_threshold: Optional[float],
+        is_threshold_relative: bool,
+        app_kwargs: dict,
+    ) -> "TunerInterface":
+        # By default, keep_threshold == tuner_threshold
+        self.keep_threshold = qos_keep_threshold or qos_tuner_threshold
+        if is_threshold_relative:
+            baseline_qos, _ = self.app.measure_qos_perf({}, False)
+            qos_tuner_threshold = baseline_qos - qos_tuner_threshold
+            self.keep_threshold = baseline_qos - self.keep_threshold
+        opentuner_args.test_limit = max_iter
+        return TunerInterface(
+            opentuner_args,
+            self.app,
+            qos_tuner_threshold,
+            self.keep_threshold,
+            max_iter,
+            **app_kwargs,
+        )
+
+    def _get_app_kwargs(self, **kwargs):
+        return {}
+
+    @classmethod
+    def _get_config_class(cls) -> Type[Config]:
+        return Config
+
 def opentuner_default_args():
     from opentuner import default_argparser
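When is_threshold_relative is set, both thresholds are user-facing "allowed QoS drop" values, and _get_tuner_interface converts them to absolute QoS by measuring the baseline once and subtracting. A worked example with hypothetical numbers:

# Hypothetical values for illustration only.
baseline_qos = 93.0          # from self.app.measure_qos_perf({}, False)
qos_tuner_threshold = 3.0    # allow a QoS drop of up to 3.0 while tuning
qos_keep_threshold = None    # defaults to the tuner threshold

keep = qos_keep_threshold or qos_tuner_threshold   # 3.0
tuner_thres = baseline_qos - qos_tuner_threshold   # 90.0 (absolute)
keep_thres = baseline_qos - keep                   # 90.0 (absolute)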
@@ -208,6 +229,7 @@ class TunerInterface(MeasurementInterface):
         tuner_thres: float,
         keep_thres: float,
         test_limit: int,
+        **app_kwargs,
     ):
         from opentuner.measurement.inputmanager import FixedInputManager
         from opentuner.search.objective import ThresholdAccuracyMinimizeTime
@@ -217,6 +239,7 @@ class TunerInterface(MeasurementInterface):
         self.tune_thres = tuner_thres
         self.keep_thres = keep_thres
         self.pbar = tqdm(total=test_limit, leave=False)
+        self.app_kwargs = app_kwargs

         objective = ThresholdAccuracyMinimizeTime(tuner_thres)
         input_manager = FixedInputManager(size=len(self.app.op_knobs))
...
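The two hooks introduced above, _get_app_kwargs and _get_config_class, are the framework's extension points: a subclass forwards extra tune() keyword arguments into TunerInterface (which stores them as self.app_kwargs) and picks which Config subclass wraps the rows read back from the OpenTuner database. A hypothetical subclass, with invented names purely for illustration (ApproxModeledTuner in the next file is the real instance of this pattern):

class EnergyConfig(Config):
    # Hypothetical: a Config carrying one extra, optional field.
    def __init__(self, qos, perf, knobs, calib_qos=None, energy=None):
        super().__init__(qos, perf, knobs, calib_qos)
        self.energy = energy

class EnergyAwareTuner(ApproxTuner):
    # Hypothetical: forwards an `energy_model` argument to TunerInterface.
    def _get_app_kwargs(self, energy_model: str = "none", **kwargs):
        return {"energy_model": energy_model}

    @classmethod
    def _get_config_class(cls) -> Type[Config]:
        return EnergyConfig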
 import abc
-from typing import Callable, Dict, List, Tuple, Union
+import logging
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union

 import torch

-from .approxapp import ApproxApp, KnobsT
+from .approxapp import ApproxApp, ApproxTuner, Config, KnobsT
+
+msg_logger = logging.getLogger(__name__)

 class ModeledApp(ApproxApp, abc.ABC):
@@ -65,7 +68,7 @@ class ModeledApp(ApproxApp, abc.ABC):
f'"{qos_model}" is an invalid value for qos_model ' f'"{qos_model}" is an invalid value for qos_model '
f"(choose from {list(self._qos_models.keys())})" f"(choose from {list(self._qos_models.keys())})"
) )
qos = self._qos_models[qos_model].measure_qos(with_approxes, is_calibration) qos = self._qos_models[qos_model].measure_qos(with_approxes)
# Same goes for perf # Same goes for perf
if perf_model != "none": if perf_model != "none":
if perf_model not in self._perf_models: if perf_model not in self._perf_models:
@@ -73,10 +76,13 @@ class ModeledApp(ApproxApp, abc.ABC):
f'"{perf_model}" is an invalid value for perf_model ' f'"{perf_model}" is an invalid value for perf_model '
f"(choose from {list(self._perf_models.keys())})" f"(choose from {list(self._perf_models.keys())})"
) )
perf = self._perf_models[perf_model].measure_perf(with_approxes, is_calibration) perf = self._perf_models[perf_model].measure_perf(with_approxes)
assert qos is not None and perf is not None assert qos is not None and perf is not None
return qos, perf return qos, perf
def get_tuner(self) -> "ApproxModeledTuner":
return ApproxModeledTuner(self)
class IPerfModel(abc.ABC): class IPerfModel(abc.ABC):
"""Abstract base class for models that provide performance prediction.""" """Abstract base class for models that provide performance prediction."""
@@ -163,7 +169,7 @@ class QoSModelP1(IQoSModel):
     def measure_qos(self, with_approxes: KnobsT) -> float:
         """Implementation of model."""
-        pass
+        return 0.0

 class QoSModelP2(IQoSModel):
@@ -189,4 +195,64 @@ class QoSModelP2(IQoSModel):
     def measure_qos(self, with_approxes: KnobsT) -> float:
         """Implementation of model."""
-        pass
+        return 0.0
+
+class ValConfig(Config):
+    def __init__(
+        self,
+        qos: float,
+        perf: float,
+        knobs: KnobsT,
+        calib_qos: Optional[float] = None,
+        validated_qos: Optional[float] = None,
+    ) -> None:
+        super().__init__(qos, perf, knobs, calib_qos)
+        self.validated_qos = validated_qos
+
+class ApproxModeledTuner(ApproxTuner):
+    def tune(
+        self,
+        max_iter: int,
+        qos_tuner_threshold: float,
+        qos_keep_threshold: Optional[float] = None,
+        is_threshold_relative: bool = False,
+        take_best_n: Optional[int] = None,
+        calibrate: bool = True,
+        validate: Optional[bool] = None,
+        perf_model: str = "none",
+        qos_model: str = "none",
+    ) -> List[ValConfig]:
+        ret = super().tune(
+            max_iter=max_iter,
+            qos_tuner_threshold=qos_tuner_threshold,
+            qos_keep_threshold=qos_keep_threshold,
+            is_threshold_relative=is_threshold_relative,
+            take_best_n=take_best_n,
+            calibrate=calibrate,
+            perf_model=perf_model,
+            qos_model=qos_model,
+        )
+        if validate is None:
+            validate = qos_model != "none"
+        if validate:
+            self.validate_configs_(self.best_configs)
+        return ret
+
+    def validate_configs_(self, configs: List[ValConfig]):
+        from tqdm import tqdm
+
+        for cfg in tqdm(configs, leave=False):
+            cfg: ValConfig
+            if cfg.validated_qos is not None:
+                continue
+            cfg.validated_qos, _ = self.app.measure_qos_perf(cfg.knobs, False)
msg_logger.debug(f"Validation: {cfg.qos} (mean) -> {cfg.calib_qos} (mean)")
+
+    def _get_app_kwargs(self, perf_model: str, qos_model: str):
+        return {"perf_model": perf_model, "qos_model": qos_model}
+
+    @classmethod
+    def _get_config_class(cls) -> Type[Config]:
+        return ValConfig
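Taken together, a user of a concrete ModeledApp can now run model-driven tuning end to end without touching TunerInterface; the tests below exercise exactly this path. A minimal usage sketch (assuming `app` is a concrete ModeledApp; the model names follow the tests):

# `app` is a concrete ModeledApp subclass (assumption for this sketch).
tuner = app.get_tuner()        # -> ApproxModeledTuner
configs = tuner.tune(
    max_iter=100,
    qos_tuner_threshold=3.0,   # relative: allow a QoS drop of up to 3.0
    is_threshold_relative=True,
    perf_model="perf_linear",  # predicted perf instead of measured
    qos_model="qos_p1",        # predicted QoS instead of measured
)
for cfg in configs:
    # cfg.qos is model-estimated; validated_qos was measured on the real
    # app, since qos_model != "none" makes `validate` default to True.
    print(cfg.knobs, cfg.qos, cfg.validated_qos)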
@@ -55,6 +55,12 @@ class TestTorchAppTuning(TorchAppSetUp):
         if len(tuner.kept_configs) >= 10:
             self.assertEqual(len(tuner.best_configs), 10)

+    def test_enum_models(self):
+        self.assertSetEqual(
+            set(model.name for model in self.app.get_models()),
+            {"perf_linear", "qos_p1", "qos_p2"},
+        )
+
 class TestTorchAppTunerResult(TorchAppSetUp):
     @classmethod
@@ -80,3 +86,30 @@ class TestTorchAppTunerResult(TorchAppSetUp):
         configs = self.tuner.best_configs
         for c in configs:
             self.assertAlmostEqual(c.calib_qos, c.qos)
+
+class TestModeledTuning(TorchAppSetUp):
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        cls.baseline, _ = cls.app.measure_qos_perf({}, False)
+
+    def test_qos_p1(self):
+        tuner = self.app.get_tuner()
+        tuner.tune(
+            100,
+            3.0,
+            is_threshold_relative=True,
+            perf_model="perf_linear",
+            qos_model="qos_p1",
+        )
+
+    def test_qos_p2(self):
+        tuner = self.app.get_tuner()
+        tuner.tune(
+            100,
+            3.0,
+            is_threshold_relative=True,
+            perf_model="perf_linear",
+            qos_model="qos_p2",
+        )