Skip to content
Snippets Groups Projects
Commit d8da8106 authored by Yifan Zhao
Browse files

Cosmetic changes to logger output and progress bars

parent 5873347b
No related branches found
No related tags found
No related merge requests found
......@@ -7,7 +7,7 @@ import torch
from torch.utils.data.dataset import Dataset
RetT = Tuple[torch.Tensor, torch.Tensor]
# Module-scoped logger: named after this module so its records can be
# filtered and configured separately from the root logger.
msg_logger = logging.getLogger(__name__)
PathLike = Union[Path, str]
......
import abc
import logging
from pathlib import Path
from typing import Dict, Generic, List, Optional, Tuple, Type, TypeVar
from typing import Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar
import matplotlib.pyplot as plt
import numpy as np
......@@ -111,6 +111,14 @@ class ApproxTuner(Generic[T]):
from ._dbloader import read_opentuner_db
n_ops, n_knobs = len(self.app.ops), len(self.app.knobs)
msg_logger.info(
"Started tuning app %s with %d ops and %d unique knob types",
self.app.name,
n_ops,
n_knobs,
)
msg_logger.info("At most %d iterations", max_iter)
opentuner_args = opentuner_default_args()
tuner = self._get_tuner_interface(
opentuner_args,
......@@ -124,6 +132,11 @@ class ApproxTuner(Generic[T]):
trm = TuningRunMain(tuner, opentuner_args)
# TuningRunMain.__init__ initializes its own logger, so we'll override it and use ours
override_opentuner_config()
msg_logger.info(
"Estimated size of search space: %d", trm.manipulator.search_space_size()
)
# A little bit of hack to get the _real_ progress when duplicated configs exist
tuner.set_progress_getter(lambda: trm.search_driver.test_count)
# This is where opentuner runs
trm.main()
# Parse and store results
......@@ -137,7 +150,16 @@ class ApproxTuner(Generic[T]):
cfg for cfg in self.all_configs if cfg.qos > self.keep_threshold
]
self.best_configs = self.take_best_configs(self.kept_configs, take_best_n)
msg_logger.info(
"Tuning finished with %d configs in total, "
"%d configs above keeping threshold, "
"and %d configs selected on tradeoff curve",
len(self.all_configs),
len(self.kept_configs),
len(self.best_configs),
)
if calibrate:
msg_logger.info("Calibrating configurations on calibration inputs")
self.calibrate_configs_(self.best_configs)
return self.best_configs
......@@ -206,6 +228,11 @@ class ApproxTuner(Generic[T]):
qos_tuner_threshold = baseline_qos - qos_tuner_threshold
self.keep_threshold = baseline_qos - self.keep_threshold
opentuner_args.test_limit = max_iter
msg_logger.info(
"Tuner QoS threshold: %f; keeping configurations with QoS >= %f",
qos_tuner_threshold,
self.keep_threshold,
)
return TunerInterface(
opentuner_args,
self.app,
......@@ -226,7 +253,9 @@ class ApproxTuner(Generic[T]):
def opentuner_default_args():
    """Return OpenTuner's default arguments with duplicate-config warnings disabled.

    The default argument parser is run on an empty argument list, so every
    option keeps its default value; ``no_dups`` is then switched on.

    Returns:
        The parsed argument namespace, which callers may adjust further
        before handing it to the tuner.
    """
    # Imported inside the function, as the original does.
    from opentuner import default_argparser

    args = default_argparser().parse_args([])
    args.no_dups = True  # Don't print duplicated config warnings
    return args
class TunerInterface(MeasurementInterface):
......@@ -259,6 +288,9 @@ class TunerInterface(MeasurementInterface):
objective=objective,
)
def set_progress_getter(self, getter: Callable[[], int]):
    """Install the callable used to read the current tuning progress.

    Args:
        getter: Zero-argument callable returning the progress count as an
            int. It is stored on the instance as ``progress_getter``.

    NOTE(review): the measurement step reads ``self.progress_getter`` when it
    advances the progress bar, so this presumably must be called before
    tuning starts -- confirm that no default getter is set elsewhere.
    """
    self.progress_getter = getter
def manipulator(self) -> ConfigurationManipulator:
"""Define the search space by creating a ConfigurationManipulator."""
manipulator = ConfigurationManipulator()
......@@ -275,9 +307,12 @@ class TunerInterface(MeasurementInterface):
qos, perf = self.app.measure_qos_perf(cfg, False, **self.app_kwargs)
# Print a debug message for each config in tuning and keep threshold
self.print_debug_config(qos, perf)
self.pbar.update()
self.pbar.update(self.progress_getter() - self.pbar.n)
return Result(time=perf, accuracy=qos)
def save_final_config(self, config):
    """Close the progress bar held in ``self.pbar``.

    Args:
        config: Not used by this implementation.

    NOTE(review): presumably this is the hook the tuning framework invokes
    once tuning has finished -- confirm against the base class.
    """
    self.pbar.close()
def print_debug_config(self, qos: float, perf: float):
gt_tune, gt_keep = qos > self.tune_thres, qos > self.keep_thres
if not gt_tune and not gt_keep:
......@@ -289,5 +324,5 @@ class TunerInterface(MeasurementInterface):
else:
kind = "tuning and keep"
msg_logger.debug(
f"Found config in {kind} threshold: QoS = {qos}, perf = {perf}"
f"Found config within {kind} threshold: QoS = {qos}, perf = {perf}"
)
......@@ -224,9 +224,8 @@ class QoSModelP1(IQoSModel):
self._try_append_save(self.storage, op, knob, delta_tensor)
super()._init()
@staticmethod
def _load(path: Path) -> Iterator[Tuple[str, str, torch.Tensor]]:
msg_logger.info(f"Found pickle at {path}")
def _load(self, path: Path) -> Iterator[Tuple[str, str, torch.Tensor]]:
msg_logger.info(f"Model {self.name} found saved model at {path}")
with path.open("rb") as f:
while True:
try:
......@@ -314,6 +313,7 @@ class QoSModelP2(IQoSModel):
)
else:
msg_logger.warning("Loaded profile does not have app name identifier")
msg_logger.info(f"Model {self.name} loaded saved model at {path}")
return df, baseline_qos
def _save(self, path: Path):
......@@ -360,9 +360,20 @@ class ApproxModeledTuner(ApproxTuner):
perf_model: str = "none",
qos_model: str = "none",
) -> List[ValConfig]:
qos_desc = (
"no model for qos" if qos_model == "none" else f'qos model "{qos_model}"'
)
perf_desc = (
"no model for performance"
if perf_model == "none"
else f'performance model "{perf_model}"'
)
msg_logger.info("Starting tuning with %s and %s", qos_desc, perf_desc)
if qos_model != "none":
msg_logger.info("Initializing qos model %s", qos_model)
self.app._init_model(qos_model)
if perf_model != "none":
msg_logger.info("Initializing performance model %s", perf_model)
self.app._init_model(perf_model)
ret = super().tune(
max_iter=max_iter,
......@@ -374,9 +385,13 @@ class ApproxModeledTuner(ApproxTuner):
perf_model=perf_model,
qos_model=qos_model,
)
if validate is None:
validate = qos_model != "none"
if validate:
if validate is None and qos_model != "none":
msg_logger.info(
'Validating configurations due to using qos model "%s"', qos_model
)
self.validate_configs_(self.best_configs)
elif validate:
msg_logger.info("Validating configurations as user requested")
self.validate_configs_(self.best_configs)
return ret
......
import logging
from collections import OrderedDict
from typing import Iterable, Tuple
import pandas
import torch
import torch.nn as nn
from .indexing import ModuleIndexer
_summary_used = False
msg_logger = logging.getLogger(__name__)
def get_flops(module: nn.Module, input_shape, output_shape):
......@@ -50,12 +53,12 @@ def get_flops(module: nn.Module, input_shape, output_shape):
handler = type_dispatch.get(type(module))
if not handler:
if not list(module.children()):
_print_once(f"Leaf module {module} cannot be handled")
_warn_once(f"Module {module} cannot be handled; its FLOPs will be estimated as 0")
return 0.0
try:
return handler()
except RuntimeError as e:
_print_once(f'Error "{e}" when handling {module}')
_warn_once(f'Error "{e}" when handling {module}; its FLOPs will be estimated as 0')
return 0.0
......@@ -92,7 +95,7 @@ def get_summary(model: nn.Module, model_args: Tuple) -> pandas.DataFrame:
params=n_params,
flops=flops,
trainable=trainable,
is_leaf=is_leaf
is_leaf=is_leaf,
)
def register_hook(module: nn.Module):
......@@ -127,9 +130,9 @@ def default_handle_sizes(value):
if isinstance(value, Iterable):
return [list(i.size()) for i in value]
except AttributeError as e:
_print_once(f"Cannot handle {type(value)}: error {e}")
_warn_once(f"Cannot get shape of {type(value)}: error {e}")
return None
_print_once(f"Cannot handle {type(value)}")
_warn_once(f"Cannot get shape of {type(value)}")
return None
......@@ -143,7 +146,7 @@ def _get_numel(shape):
return torch.prod(torch.tensor(shape)).item()
def _warn_once(*args, **kwargs):
    """Log a warning via ``msg_logger`` unless ``_summary_used`` is already truthy.

    All arguments are forwarded unchanged to ``msg_logger.warning``, so the
    first positional argument is the message and any keyword arguments must
    be ones the logging API accepts.

    NOTE(review): where ``_summary_used`` gets switched on is not visible
    here; while it stays falsy every call logs -- confirm the intended
    "once" behavior against the rest of the module.
    """
    if _summary_used:
        return
    msg_logger.warning(*args, **kwargs)


# Earlier name of this helper, kept so existing call sites keep working.
_print_once = _warn_once
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment