Skip to content
Snippets Groups Projects
run_tuner.py 12.95 KiB
#!/usr/bin/env python
#
# Development-time Tuner with Algorithmic Approximations:
# Approximations: Perforation, Sampling with varying knobs for rate, skip offset
import copy
import logging
import os
import shutil
import time
from pathlib import Path
from typing import List, Tuple

import numpy as np
import opentuner
from opentuner import ConfigurationManipulator, EnumParameter, MeasurementInterface
from opentuner.measurement.inputmanager import FixedInputManager
from opentuner.search.objective import ThresholdAccuracyMinimizeTime
from opentuner.tuningrunmain import TuningRunMain
from torch.nn import Module
from tqdm import tqdm

from exp import Benchmark, ConfigMeasurer, ExpState, TuningTime, batch_id, bench_tuner_data, is_dev_time
from models import get_all_output, networks, QoS
from toolkit import ConfigT
from toolkit.estimators import WeightedLinearQoSEstimator
from utils import Config, config, reapply_last_config

# Logger shared by everything in this module.
msg_logger: logging.Logger = logging.getLogger(name=__name__)
# When True, QoS is estimated through the proxy built by `init_proxy` instead
# of being measured directly (see `TunerDriver.get_accuracy`).
use_proxy: bool = False
# Validation-run count recorded for configs that use PROMISE knob values.
n_promise_valid_runs: int = 30
# Confidence level handed to the proxy QoS estimator.
confidence_level: float = 0.95


def init_proxy(ni: ConfigMeasurer, pickle_path: Path):
    """Build a `WeightedLinearQoSEstimator` for the network described by `ni`.

    The estimator receives three callbacks bound to `ni`'s validation loader:
    one running a network to collect its outputs, one computing QoS, and one
    checking that the baseline QoS minus the mean QoS is below 3.0.
    `pickle_path` is passed as the estimator's `storage`.
    """

    def qos_of(inputs_):
        # `ni.val_loader` is read on every call, not captured once.
        return ni.get_qos(inputs_, ni.val_loader)

    def within_threshold(inputs_):
        qos_values = np.array([qos_of(item) for item in inputs_])
        return ni.val_qos - qos_values.mean() < 3.0

    def collect_outputs(net: Module):
        return get_all_output(net, ni.val_loader)

    estimator = WeightedLinearQoSEstimator(
        ni.nas,
        collect_outputs,
        qos_of,
        within_threshold,
        confidence_level,
        storage=pickle_path,
    )
    return estimator


class Timer:
    """Context manager that times its body and records the elapsed seconds.

    On exit, including exit through an exception, the interval is reported via
    ``timer_state.add_timer(timer_name, interval)``. Exceptions are not
    suppressed.
    """

    def __init__(self, timer_state: "TuningTime", timer_name: str):
        self.timer_state = timer_state
        self.name = timer_name
        self.start = None

    def __enter__(self):
        # perf_counter is monotonic, so the measured interval cannot jump or go
        # negative if the system clock is adjusted while tuning is running.
        self.start = time.perf_counter()
        return self

    def __exit__(self, *args):
        interval = time.perf_counter() - self.start
        self.timer_state.add_timer(self.name, interval)


class TunerDriver:
    """Drive a complete tuning session for one benchmark.

    Construction prepares the result directory, snapshots the current git
    state, and derives QoS thresholds from the network's validation and test
    QoS. `tuner_exec` then runs one OpenTuner search per tuner threshold and
    post-processes the collected configs: pre-filter, validate, test, filter,
    and plot.
    """

    def __init__(self, bench: Benchmark):
        self.bench = bench
        msg_logger.info(f"Tuning for model {self.bench.model_name}")
        # Initialize folder (clears previous results except the proxy pickle).
        self._init_folder(bench)
        # Take a snapshot of current code.
        self.take_code_snapshot()
        # Initialize network information and qos thresholds
        self.net_info = ConfigMeasurer.init_from_bench(self.bench)
        qoses = self.net_info.val_qos, self.net_info.test_qos
        qos_type = self.net_info.val_qos.__class__
        self.tuner_thres = qos_type.suggested_tuner_thresholds(self.net_info.val_qos)
        self.val_thres = qos_type.suggested_val_threshold(self.net_info.val_qos)
        self.test_thres = qos_type.suggested_test_threshold(self.net_info.test_qos)
        # Tuner states.
        self.states = ExpState(bench, qos_type, qoses)
        # Index of the current tuning run (one per tuner threshold) and the
        # iteration within it; `ProxyTuner` uses these to name configs.
        self.run_id, self.iter = 0, 0
        # Initialize proxy.
        if use_proxy:
            self.proxy = init_proxy(self.net_info, self.bench.result_dir / 'proxy.pkl')
        else:
            self.proxy = None

    @staticmethod
    def _init_folder(bench: Benchmark):
        """Prepare `bench.result_dir` for a fresh tuning run.

        Everything already inside the directory is deleted except `proxy.pkl`.
        That file is presumably kept so a stored proxy estimator can be reused.
        The directory is created if it does not exist.
        """

        def remove_file_or_folder(path: Path):
            # A symlink to a directory is unlinked rather than passed to
            # rmtree, which refuses symlinks. Broken symlinks are removed too.
            if path.is_dir() and not path.is_symlink():
                shutil.rmtree(path)
            elif path.is_file() or path.is_symlink():
                path.unlink()  # Removes file despite the surprising name

        pickle_path = bench.result_dir / 'proxy.pkl'
        # Remove everything in result folder except pickle file
        if bench.result_dir.is_dir():
            msg_logger.warning(f"!Cleaning existing result dir = {bench.result_dir}")
            for child in bench.result_dir.glob('*'):
                if child == pickle_path:
                    continue
                msg_logger.info(f"  !Removing {child}")
                remove_file_or_folder(child)
        # Create result folder if it doesn't exist
        if not bench.result_dir.is_dir():
            msg_logger.info(f"Creating output directory = {bench.result_dir}")
            os.makedirs(bench.result_dir)

    def get_default_args(self):
        """Return OpenTuner's default arguments with this run's overrides.

        `parse_args()` reads `sys.argv`. The results database is placed at
        `opentuner.db/<batch_id>.db` and the test limit is set to
        `bench.autotuner_runs`.
        """
        args = opentuner.default_argparser().parse_args()
        args.database = f"opentuner.db/{batch_id}.db"
        args.test_limit = self.bench.autotuner_runs
        # Ensure the database's parent directory exists (no-op if it does).
        Path(args.database).parent.mkdir(parents=True, exist_ok=True)
        return args

    def tuner_exec(self):
        """Run one OpenTuner search per tuner threshold, then post-process configs."""
        # Get default opentuner args
        args = self.get_default_args()
        # Start tuning for each threshold
        for i, thres in enumerate(self.tuner_thres):
            with Timer(self.states.timers, f"tuning_{i}"):
                msg_logger.info(
                    f"Tuning goal: qos >= {thres}; keeping configs with qos >= {self.val_thres}"
                )
                tuner = ProxyTuner(args, self, thres, self.val_thres)
                try:
                    # TuningRunMain.__init__ initializes its own logger, so we'll reapply our settings.
                    tuning_main = TuningRunMain(tuner, args)
                    reapply_last_config()
                    # Unleash the tuner!
                    tuning_main.main()
                finally:
                    # Remove tuner progress bar, even if tuning failed.
                    tuner.pbar.close()
                self.run_id += 1
                self.iter = 0
        # Postprocess configs
        self.process_configs()

    def calibrate_write_configs(self, configs: List[Config], is_test_set: bool):
        """Re-measure `configs` on the test or validation set and store the results.

        Each config is deep-copied, so the caller's objects are left untouched.
        The copy's QoS is updated from `net_info.actual_measure`. The copy is
        then appended to `states.tested_configs` (test set) or
        `states.validated_configs` (validation set), which is finalized at the
        end.
        """
        write_to = self.states.tested_configs if is_test_set else self.states.validated_configs
        gold_acc = self.net_info.test_qos if is_test_set else self.net_info.val_qos
        for original in tqdm(configs, leave=False):
            cfg = copy.deepcopy(original)
            # Knob values keyed by their position in `cfg.flags`.
            flags = dict(enumerate(cfg.flags))
            # NOTE(review): `val_thres` is passed even for the test set; confirm
            # whether `test_thres` was intended there.
            measured_acc, confidence = self.net_info.actual_measure(
                flags, cfg.total_runs, is_test_set, threshold=self.val_thres
            )
            prev_acc = cfg.avg_qos
            cfg.update_acc(measured_acc, confidence, gold_acc)
            msg_logger.debug(f"{prev_acc} (mean) -> {cfg.avg_qos} (mean)")
            write_to.append(cfg)
        write_to.finalize_dump()

    @staticmethod
    def filter_configs(
            validation: List[Config], test: List[Config],
            vali_threshold: QoS, test_threshold: QoS
    ) -> Tuple[List[Config], List[Config]]:
        """Filter configs by loss thresholds and intersect test with validation.

        Returns ``(filtered_validation, filtered_test)``:

        - `filtered_validation` holds the validation configs whose `avg_loss`
          is ``<= vali_threshold``, in input order.
        - `filtered_test` holds the test configs whose `avg_loss` is
          ``<= test_threshold`` and whose `fname` also survived validation
          filtering. They are deduplicated by `fname` (the last one wins) and
          ordered by each name's first appearance in `test`.
        """
        # Filter validation and test set by their respective thresholds
        filtered_validation = [c for c in validation if c.avg_loss <= vali_threshold]
        filtered_test = [c for c in test if c.avg_loss <= test_threshold]
        # Test configs also need to be a subset of validation configs.
        valid_names = {c.fname for c in filtered_validation}
        name_to_filtered = {c.fname: c for c in filtered_test}
        # Iterate the dict, not a set of names: set order of strings depends on
        # hash randomization, which would make the output order vary per run.
        filtered_test_ = [c for name, c in name_to_filtered.items() if name in valid_names]
        return filtered_validation, filtered_test_

    def process_configs(self):
        """Pre-filter, calibrate, filter and dump all configs found by tuning, then plot."""
        # Finalize all configs because tuning is done.
        # (this may not do anything now but will in the future)
        self.states.all_configs.finalize_dump()
        all_configs = self.states.all_configs.configs
        # Pre-filter configs by a wide pareto margin
        filtered_configs = config.is_pareto_efficient(all_configs, ratio=0.05, n_min=50, n_max=50)
        msg_logger.info(f"Prefilter yields {len(filtered_configs)} configs from {len(all_configs)}")
        self.states.filtered_configs.finalize_dump(with_configs=filtered_configs)
        # Calibrate prefiltered configs (validation step)
        with Timer(self.states.timers, "validate"):
            self.calibrate_write_configs(filtered_configs, is_test_set=False)
            validated_configs = self.states.validated_configs.configs
        # Calibrate prefiltered configs on test set (test step)
        with Timer(self.states.timers, "test"):
            self.calibrate_write_configs(filtered_configs, is_test_set=True)
            tested_configs = self.states.tested_configs.configs
        # Filter valid and test set configs by thresholds
        valid_configs, test_configs = self.filter_configs(
            validated_configs, tested_configs, self.val_thres, self.test_thres
        )
        self.states.valid_configs.finalize_dump(valid_configs)
        self.states.test_configs.finalize_dump(test_configs)
        # Finalize data input and plot everything.
        self.states.finalize_plot()

    def take_code_snapshot(self):
        """Record the current git commit and outstanding diff under `result_dir/references`.

        Writes `git_commit.txt` (HEAD SHA) and `diff.txt` (output of
        ``git diff <HEAD tree>``), both UTF-8 encoded. The `references`
        directory must not already exist.
        """
        import git
        msg_logger.info("Taking git snapshot")
        ref_dir = self.bench.result_dir / "references"
        ref_dir.mkdir()
        # Write current git commit (SHA id)
        repo = git.Repo(search_parent_directories=True)
        sha = repo.head.object.hexsha
        msg_logger.info(f"Current code is at commit {sha}")
        (ref_dir / 'git_commit.txt').write_text(sha, encoding='utf-8')
        # Also put all outstanding code change in a diff file.
        # This way changes in all git-tracked files are captured.
        # Explicit UTF-8 so non-ASCII source changes don't fail on other locales.
        t = repo.head.commit.tree
        (ref_dir / 'diff.txt').write_text(repo.git.diff(t), encoding='utf-8')

    def make_config_name(self) -> str:
        """Return a config name of the form ``<model>_<run_id>_<iter>``."""
        return f"{self.bench.model_name}_{self.run_id}_{self.iter}"

    def get_accuracy(self, cfg: ConfigT) -> Tuple[QoS, QoS, int]:
        """Estimate the validation QoS of `cfg`.

        Returns ``(mean_qos, confident_qos, n_validation_runs)``. When
        `use_proxy` is set, both QoS values come from `proxy_estimate`.
        Otherwise one `actual_measure` run supplies both. A config using any
        knob value in 1..7 (treated as PROMISE knobs) reports
        `n_promise_valid_runs` validation runs; every other config reports 1.
        """
        has_promise_flags = set(cfg.values()).intersection(set(range(1, 7 + 1)))
        config_validation_runs = n_promise_valid_runs if has_promise_flags else 1
        if use_proxy:
            mean_acc, confidence_acc = self.net_info.proxy_estimate(cfg, self.proxy)
            # Without PROMISE knobs the proxy is expected to be deterministic.
            assert has_promise_flags or (mean_acc == confidence_acc)
        else:
            mean_acc, _ = self.net_info.actual_measure(cfg, 1, is_test_set=False)
            confidence_acc = mean_acc
        return mean_acc, confidence_acc, config_validation_runs


class ProxyTuner(MeasurementInterface):
    """OpenTuner measurement interface that scores approximation configs.

    A configuration maps each layer id to one of its knobs. QoS comes from
    `TunerDriver.get_accuracy` and cost from `Benchmark.compute_config_cost`.
    Configs whose confident QoS is above `accept_thres` are appended to the
    driver's `states.all_configs`.
    """

    def __init__(self, args, driver: TunerDriver, tuner_thres: QoS, accept_thres: QoS):
        self.tuner_driver = driver
        self.model_info = driver.net_info
        self.bench = driver.bench
        self.tuner_thres = tuner_thres
        self.all_configs = driver.states.all_configs
        # One progress-bar tick per measured configuration.
        self.pbar = tqdm(total=args.test_limit, leave=False)
        threshold_objective = ThresholdAccuracyMinimizeTime(tuner_thres.to_scalar())
        layer_inputs = FixedInputManager(size=driver.bench.get_n_layers())
        super().__init__(
            args,
            program_name=self.bench.model_name,
            input_manager=layer_inputs,
            objective=threshold_objective,
        )
        self.accept_thres = accept_thres

    def manipulator(self) -> ConfigurationManipulator:
        """Build the search space: one enum parameter per layer over its knobs."""
        search_space = ConfigurationManipulator()
        for layer_id, layer_knobs in self.model_info.get_knobs().items():
            search_space.add_parameter(EnumParameter(layer_id, layer_knobs))
        return search_space

    def seed_configurations(self):
        """Seed the search with the baseline config when the benchmark asks for it."""
        if not self.bench.use_seed:
            return []
        return [self.bench.get_baseline_config(not is_dev_time)]

    def run(self, desired_result, input_, limit):
        """Measure one configuration and return its OpenTuner `Result`.

        `Result.time` is the config's computational cost. `Result.accuracy` is
        the confident QoS as a scalar relative to `tuner_thres`.
        """
        cfg: ConfigT = desired_result.configuration.data
        # Mean QoS estimate, confident QoS estimate, and validation-run count.
        mean_acc, confident_acc, n_runs = self.tuner_driver.get_accuracy(cfg)
        # Cost of the config, plus the speedup it yields.
        total_comps, speedup = self.bench.compute_config_cost(cfg)
        result = opentuner.resultsdb.models.Result()
        result.time = total_comps
        # OpenTuner cannot compare custom QoS objects, so hand it a scalar.
        result.accuracy = confident_acc.to_scalar(relative_to=self.tuner_thres)

        accepted = confident_acc > self.accept_thres
        if accepted:
            ordered_knobs = [cfg[layer_id] for layer_id in sorted(cfg.keys())]
            accepted_config = Config(
                mean_acc, self.model_info.val_qos,
                self.tuner_driver.make_config_name(), ordered_knobs,
                n_runs, 95.0, total_comps, speedup
            )
            self.all_configs.append(accepted_config)
            msg_logger.debug(
                f"Config chosen with accuracy (mean) = {mean_acc}, (95%) = {confident_acc} "
                f"and speedup = {speedup}"
            )
        self.tuner_driver.iter += 1
        self.pbar.update()
        return result

    def save_final_config(self, configuration):
        """Log the final configuration chosen by OpenTuner."""
        msg_logger.info(f"Final configuration {configuration.data}")
        msg_logger.info("Done with Autotuning run")
        msg_logger.info("Done with Autotuning run")


if __name__ == '__main__':
    # Every known network must have tuner data registered for it.
    assert set(networks.keys()) <= set(bench_tuner_data.keys())
    networks_to_tune = ('alexnet2_hpvm',)
    for network in networks_to_tune:
        bench_: Benchmark = bench_tuner_data[network]
        driver = TunerDriver(bench_)
        driver.tuner_exec()