run_tuner.py (authored by Akash Kothari)
#!/usr/bin/env python
#
# Development-time Tuner with Algorithmic Approximations:
# Approximations: perforation and sampling, with varying knobs for rate and skip offset
import copy
import logging
import os
import shutil
import time
from pathlib import Path
from typing import List, Tuple
import numpy as np
import opentuner
from opentuner import ConfigurationManipulator, EnumParameter, MeasurementInterface
from opentuner.measurement.inputmanager import FixedInputManager
from opentuner.search.objective import ThresholdAccuracyMinimizeTime
from opentuner.tuningrunmain import TuningRunMain
from torch.nn import Module
from tqdm import tqdm
from exp import Benchmark, ConfigMeasurer, ExpState, TuningTime, batch_id, bench_tuner_data, is_dev_time
from models import get_all_output, networks, QoS
from toolkit import ConfigT
from toolkit.estimators import WeightedLinearQoSEstimator
from utils import Config, config, reapply_last_config
msg_logger = logging.getLogger(__name__)
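# Module-level tuning settings: `use_proxy` toggles the QoS-estimating proxy,
# `n_promise_valid_runs` sets how many validation runs a config containing
# PROMISE flags gets, and `confidence_level` parameterizes the proxy estimator.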
use_proxy = False
n_promise_valid_runs = 30
confidence_level = 0.95
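# Build a WeightedLinearQoSEstimator proxy over the validation set. The fitted
# estimator is persisted at `pickle_path`; note that `_init_folder` deliberately
# preserves this file when cleaning the result directory.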
def init_proxy(ni: ConfigMeasurer, pickle_path: Path):
def acc_crit(inputs_):
return ni.get_qos(inputs_, ni.val_loader)
def threshold_eval(inputs_):
accs = np.array([acc_crit(x) for x in inputs_])
return ni.val_qos - accs.mean() < 3.0
def run_model(net: Module):
return get_all_output(net, ni.val_loader)
return WeightedLinearQoSEstimator(
ni.nas, run_model, acc_crit, threshold_eval, confidence_level, storage=pickle_path
)
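# Context manager that measures the wall-clock time of a block and records it
# under `timer_name` in the experiment's TuningTime state.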
class Timer:
def __init__(self, timer_state: TuningTime, timer_name: str):
self.timer_state = timer_state
self.name = timer_name
self.start = None
def __enter__(self):
self.start = time.time()
return self
def __exit__(self, *args):
end = time.time()
interval = end - self.start
self.timer_state.add_timer(self.name, interval)
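# Drives a full tuning session for one benchmark: prepares the result folder,
# snapshots the code, runs OpenTuner once per suggested QoS threshold, then
# calibrates, filters, and plots the resulting configurations.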
class TunerDriver:
def __init__(self, bench: Benchmark):
self.bench = bench
msg_logger.info(f"Tuning for model {self.bench.model_name}")
# Initialize folder.
self._init_folder(bench)
# Take a snapshot of current code.
self.take_code_snapshot()
# Initialize network information and qos thresholds
self.net_info = ConfigMeasurer.init_from_bench(self.bench)
qoses = self.net_info.val_qos, self.net_info.test_qos
qos_type = self.net_info.val_qos.__class__
self.tuner_thres = qos_type.suggested_tuner_thresholds(self.net_info.val_qos)
self.val_thres = qos_type.suggested_val_threshold(self.net_info.val_qos)
self.test_thres = qos_type.suggested_test_threshold(self.net_info.test_qos)
# Tuner states.
self.states = ExpState(bench, qos_type, qoses)
        # Current run and iteration counters; `ProxyTuner` uses these to name configs.
self.run_id, self.iter = 0, 0
# Initialize proxy.
if use_proxy:
self.proxy = init_proxy(self.net_info, self.bench.result_dir / 'proxy.pkl')
else:
self.proxy = None
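    # Clean the benchmark's result directory, keeping only the cached proxy
    # pickle, and (re)create the directory if it does not exist.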
@staticmethod
def _init_folder(bench: Benchmark):
def remove_file_or_folder(path: Path):
if path.is_dir():
                shutil.rmtree(path)
elif path.is_file():
path.unlink() # Removes file despite the surprising name
pickle_path = bench.result_dir / 'proxy.pkl'
# Remove everything in result folder except pickle file
if bench.result_dir.is_dir():
msg_logger.warning(f"!Cleaning existing result dir = {bench.result_dir}")
for child in bench.result_dir.glob('*'):
if child == pickle_path:
continue
msg_logger.info(f" !Removing {child}")
remove_file_or_folder(child)
# Create result folder if it doesn't exist
if not bench.result_dir.is_dir():
msg_logger.info(f"Creating output directory = {bench.result_dir}")
os.makedirs(bench.result_dir)
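    # Parse OpenTuner's default command-line arguments and point them at this
    # run's database file and test limit (the number of autotuner runs).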
def get_default_args(self):
args = opentuner.default_argparser().parse_args()
args.database = f"opentuner.db/{batch_id}.db"
args.test_limit = self.bench.autotuner_runs
parent = Path(args.database).parent
if not parent.is_dir():
os.makedirs(parent, exist_ok=True)
return args
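    # Main tuning loop: one OpenTuner run per suggested tuner threshold,
    # followed by post-processing of all collected configurations.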
def tuner_exec(self):
# Get default opentuner args
args = self.get_default_args()
# Start tuning for each threshold
for i, thres in enumerate(self.tuner_thres):
with Timer(self.states.timers, f"tuning_{i}"):
msg_logger.info(
f"Tuning goal: qos >= {thres}; keeping configs with qos >= {self.val_thres}"
)
tuner = ProxyTuner(args, self, thres, self.val_thres)
# TuningRunMain.__init__ initializes its own logger, so we'll reapply our settings.
tuning_main = TuningRunMain(tuner, args)
reapply_last_config()
# Unleash the tuner!
tuning_main.main()
# Remove tuner progress bar
tuner.pbar.close()
self.run_id += 1
self.iter = 0
# Postprocess configs
self.process_configs()
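    # Re-measure each config's QoS on the validation or test set and append the
    # calibrated result to the corresponding config list in the experiment state.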
def calibrate_write_configs(self, configs: List[Config], is_test_set: bool):
write_to = self.states.tested_configs if is_test_set else self.states.validated_configs
gold_acc = self.net_info.test_qos if is_test_set else self.net_info.val_qos
for cfg in tqdm(configs, leave=False):
            cfg: Config = copy.deepcopy(cfg)
            flags = dict(enumerate(cfg.flags))
measured_acc, confidence = self.net_info.actual_measure(
flags, cfg.total_runs, is_test_set, threshold=self.val_thres
)
prev_acc = cfg.avg_qos
cfg.update_acc(measured_acc, confidence, gold_acc)
new_acc = cfg.avg_qos
msg_logger.debug(f"{prev_acc} (mean) -> {new_acc} (mean)")
write_to.append(cfg)
write_to.finalize_dump()
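    # Keep only configs whose average QoS loss is within the respective
    # threshold; test-set configs must also appear among the validation configs.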
@staticmethod
def filter_configs(
validation: List[Config], test: List[Config],
vali_threshold: QoS, test_threshold: QoS
) -> Tuple[List[Config], List[Config]]:
# Filter validation and test set by their respective thresholds
filtered_validation = [
c for c in validation if c.avg_loss <= vali_threshold
]
filtered_test = [
c for c in test if c.avg_loss <= test_threshold
]
# Test configs also need to be a subset of validation configs.
name_to_filtered = {x.fname: x for x in filtered_test}
        intersect_names = set(name_to_filtered.keys()).intersection(
            x.fname for x in filtered_validation
        )
filtered_test_ = [name_to_filtered[fname] for fname in intersect_names]
return filtered_validation, filtered_test_
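    # Post-tuning pipeline: pareto-prefilter all configs, calibrate them on the
    # validation and test sets, filter by the QoS thresholds, then dump and plot.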
def process_configs(self):
# Finalize all configs because tuning is done.
# (this may not do anything now but will in the future)
self.states.all_configs.finalize_dump()
all_configs = self.states.all_configs.configs
# Pre-filter configs by a wide pareto margin
filtered_configs = config.is_pareto_efficient(all_configs, ratio=0.05, n_min=50, n_max=50)
msg_logger.info(f"Prefilter yields {len(filtered_configs)} configs from {len(all_configs)}")
self.states.filtered_configs.finalize_dump(with_configs=filtered_configs)
# Calibrate prefiltered configs (validation step)
with Timer(self.states.timers, "validate"):
self.calibrate_write_configs(filtered_configs, is_test_set=False)
validated_configs = self.states.validated_configs.configs
# Calibrate prefiltered configs on test set (test step)
with Timer(self.states.timers, "test"):
self.calibrate_write_configs(filtered_configs, is_test_set=True)
tested_configs = self.states.tested_configs.configs
# Filter valid and test set configs by thresholds
valid_configs, test_configs = self.filter_configs(
validated_configs, tested_configs, self.val_thres, self.test_thres
)
self.states.valid_configs.finalize_dump(valid_configs)
self.states.test_configs.finalize_dump(test_configs)
# Finalize data input and plot everything.
self.states.finalize_plot()
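    # Record the current git commit SHA and a diff of all uncommitted changes
    # under `<result_dir>/references` for reproducibility.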
def take_code_snapshot(self):
import git
msg_logger.info(f"Taking git snapshot")
ref_dir = self.bench.result_dir / "references"
os.mkdir(ref_dir)
# Write current git commit (SHA id)
repo = git.Repo(search_parent_directories=True)
sha = repo.head.object.hexsha
msg_logger.info(f"Current code is at commit {sha}")
with (ref_dir / 'git_commit.txt').open('w') as f:
f.write(sha)
# Also put all outstanding code change in a diff file.
# This way changes in all git-tracked files are captured.
t = repo.head.commit.tree
with (ref_dir / 'diff.txt').open('w') as f:
f.write(repo.git.diff(t))
def make_config_name(self) -> str:
return f"{self.bench.model_name}_{self.run_id}_{self.iter}"
def get_accuracy(self, cfg: ConfigT) -> Tuple[QoS, QoS, int]:
has_promise_flags = set(cfg.values()).intersection(set(range(1, 7 + 1)))
config_validation_runs = n_promise_valid_runs if has_promise_flags else 1
if use_proxy:
mean_acc, confidence_acc = self.net_info.proxy_estimate(cfg, self.proxy)
assert has_promise_flags or (mean_acc == confidence_acc)
else:
mean_acc, _ = self.net_info.actual_measure(cfg, 1, is_test_set=False)
confidence_acc = mean_acc
return mean_acc, confidence_acc, config_validation_runs
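# OpenTuner MeasurementInterface: defines the search space (one EnumParameter of
# knobs per layer), evaluates candidate configs via TunerDriver.get_accuracy and
# the benchmark's cost model, and records configs that meet the accept threshold.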
class ProxyTuner(MeasurementInterface):
def __init__(self, args, driver: TunerDriver, tuner_thres: QoS, accept_thres: QoS):
self.tuner_driver = driver
self.model_info = driver.net_info
self.bench = driver.bench
self.tuner_thres = tuner_thres
self.all_configs = driver.states.all_configs
self.pbar = tqdm(total=args.test_limit, leave=False)
objective = ThresholdAccuracyMinimizeTime(tuner_thres.to_scalar())
input_manager = FixedInputManager(size=driver.bench.get_n_layers())
super(ProxyTuner, self).__init__(
args, program_name=self.bench.model_name,
input_manager=input_manager, objective=objective
)
self.accept_thres = accept_thres
def manipulator(self) -> ConfigurationManipulator:
"""Define the search space by creating a ConfigurationManipulator."""
manipulator = ConfigurationManipulator()
for ext_layer_id, knobs in self.model_info.get_knobs().items():
manipulator.add_parameter(EnumParameter(ext_layer_id, knobs))
return manipulator
def seed_configurations(self):
"""Provide baseline config as seed if model uses seed."""
return [self.bench.get_baseline_config(not is_dev_time)] if self.bench.use_seed else []
def run(self, desired_result, input_, limit):
"""Run a given configuration then return performance and accuracy."""
cfg: ConfigT = desired_result.configuration.data
        # get_accuracy returns the estimated mean accuracy and a 95%-confidence accuracy estimate
        mean_acc, confident_acc, n_runs = self.tuner_driver.get_accuracy(cfg)
        # compute_config_cost returns the cost and speedup associated with the selected configuration
total_comps, speedup = self.bench.compute_config_cost(cfg)
        result = opentuner.resultsdb.models.Result()
        result.time = total_comps
        # Convert QoS to a scalar because opentuner does not support custom comparable datatypes
        result.accuracy = confident_acc.to_scalar(relative_to=self.tuner_thres)
# If accuracy is acceptable, write this config
if confident_acc > self.accept_thres:
config_name = self.tuner_driver.make_config_name()
cfg_values = [cfg[layer] for layer in sorted(cfg.keys())]
writing_config = Config(
mean_acc, self.model_info.val_qos, config_name, cfg_values,
n_runs, 95.0, total_comps, speedup
)
self.all_configs.append(writing_config)
msg_logger.debug(
f"Config chosen with accuracy (mean) = {mean_acc}, (95%) = {confident_acc} "
f"and speedup = {speedup}"
)
self.tuner_driver.iter += 1
self.pbar.update()
        return result
def save_final_config(self, configuration):
"""Print final configuration."""
msg_logger.info(f"Final configuration {configuration.data}")
msg_logger.info("Done with Autotuning run")
if __name__ == '__main__':
assert set(networks.keys()).issubset(set(bench_tuner_data.keys()))
for network in ('alexnet2_hpvm',):
bench_: Benchmark = bench_tuner_data[network]
TunerDriver(bench_).tuner_exec()