Skip to content
Snippets Groups Projects
Commit c7cd0f44 authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Fixed problems with config validation and calibration

parent 3b38b3b8
No related branches found
No related tags found
No related merge requests found
......@@ -147,7 +147,7 @@ class ApproxTuner(Generic[T]):
is_threshold_relative: bool = False,
take_best_n: Optional[int] = None,
test_configs: bool = True,
**kwargs
app_kwargs: dict = None
# TODO: more parameters + opentuner param forwarding
) -> List[T]:
from opentuner.tuningrunmain import TuningRunMain
......@@ -169,7 +169,7 @@ class ApproxTuner(Generic[T]):
qos_tuner_threshold,
qos_keep_threshold,
is_threshold_relative,
self._get_app_kwargs(**kwargs),
app_kwargs or {},
)
assert self.keep_threshold is not None
trm = TuningRunMain(tuner, opentuner_args)
......@@ -202,19 +202,33 @@ class ApproxTuner(Generic[T]):
len(self.best_configs),
)
if test_configs:
msg_logger.info("Checking configurations on test inputs")
self.test_configs_(self.best_configs)
msg_logger.info("Calibrating configurations on test inputs")
self.best_configs = self.test_configs(self.best_configs)
return self.best_configs
def test_configs_(self, configs: List[T]):
def test_configs(self, configs: List[Config]):
from copy import deepcopy
from tqdm import tqdm
assert self.keep_threshold is not None
if not configs:
return []
ret_configs = []
total_error = 0
for cfg in tqdm(configs, leave=False):
cfg: T
if cfg.test_qos is not None:
continue
cfg = deepcopy(cfg)
assert cfg.test_qos is None
cfg.test_qos, _ = self.app.measure_qos_cost(cfg.knobs, True)
msg_logger.debug(f"Calibration: {cfg.qos} (mean) -> {cfg.test_qos} (mean)")
total_error += abs(cfg.qos - cfg.test_qos)
if cfg.test_qos > self.keep_threshold:
ret_configs.append(cfg)
else:
msg_logger.debug("Config removed")
mean_err = total_error / len(configs)
msg_logger.info("QoS mean abs difference of calibration: %f", mean_err)
return ret_configs
@staticmethod
def take_best_configs(configs: List[T], n: Optional[int] = None) -> List[T]:
......@@ -294,9 +308,6 @@ class ApproxTuner(Generic[T]):
**app_kwargs,
)
def _get_app_kwargs(self, **kwargs):
return {}
@classmethod
def _get_config_class(cls) -> Type[Config]:
return Config
......
......@@ -192,6 +192,8 @@ class QoSModelP1(IQoSModel):
qos_metric: Callable[[torch.Tensor], float],
storage: PathLike = None,
) -> None:
from torch.nn.functional import softmax
super().__init__()
self.app = app
self.output_f = tensor_output_getter
......@@ -387,32 +389,56 @@ class ApproxModeledTuner(ApproxTuner):
qos_keep_threshold=qos_keep_threshold,
is_threshold_relative=is_threshold_relative,
take_best_n=take_best_n,
test_configs=test_configs,
cost_model=cost_model,
qos_model=qos_model,
test_configs=False, # Test configs below by ourselves
app_kwargs={"cost_model": cost_model, "qos_model": qos_model}
)
if validate_configs is None and qos_model != "none":
msg_logger.info(
'Validating configurations due to using qos model "%s"', qos_model
)
self.validate_configs_(self.best_configs)
self.best_configs = self._update_configs(self.best_configs, False)
elif validate_configs:
msg_logger.info("Validating configurations as user requested")
self.validate_configs_(self.best_configs)
self.best_configs = self._update_configs(self.best_configs, False)
if test_configs:
msg_logger.info("Calibrating configurations on test inputs")
self.best_configs = self._update_configs(self.best_configs, True)
return ret
def validate_configs_(self, configs: List[ValConfig]):
def _update_configs(self, configs: List[ValConfig], test_mode: bool):
from copy import deepcopy
from tqdm import tqdm
assert self.keep_threshold is not None
if not configs:
msg_logger.info("No configurations found.")
return []
ret_configs = []
total_error = 0
for cfg in tqdm(configs, leave=False):
cfg: ValConfig
if cfg.validated_qos is not None:
continue
cfg.validated_qos, _ = self.app.measure_qos_cost(cfg.knobs, False)
msg_logger.debug(f"Validation: {cfg.qos} (mean) -> {cfg.test_qos} (mean)")
def _get_app_kwargs(self, cost_model: str, qos_model: str):
return {"cost_model": cost_model, "qos_model": qos_model}
cfg = deepcopy(cfg)
qos, _ = self.app.measure_qos_cost(cfg.knobs, test_mode)
if test_mode:
assert cfg.test_qos is None
cfg.test_qos = qos
msg_logger.debug(f"Calibration: {cfg.qos} (mean) -> {qos} (mean)")
else:
assert cfg.validated_qos is None
cfg.validated_qos = qos
msg_logger.debug(f"Validation: {cfg.qos} (mean) -> {qos} (mean)")
total_error += abs(cfg.qos - qos)
if qos > self.keep_threshold:
ret_configs.append(cfg)
else:
msg_logger.debug("Config removed")
mean_err = total_error / len(configs)
if test_mode:
msg_logger.info("QoS mean abs difference of calibration: %f", mean_err)
else:
msg_logger.info("QoS mean abs difference of validation: %f", mean_err)
msg_logger.info("%d of %d configs remain", len(ret_configs), len(configs))
return ret_configs
@classmethod
def _get_config_class(cls) -> Type[Config]:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment