from pathlib import Path
from subprocess import PIPE, CalledProcessError
from typing import Iterable, List, Tuple, Union

import matplotlib.pyplot as plt

# Anything the functions in this module accept as a filesystem path.
PathLike = Union[Path, str]

# Marker lines that open and close each config entry in an HPVM configuration file.
conf_opening = "+++++"
conf_closing = "-----"


def profile_config_file(
    binary_path: PathLike,
    config_path: PathLike,
    output_config_path: PathLike,
    progress_bar: bool = True,
    profile_filename: str = "profile_info.txt",
    qos_filename: str = "final_accuracy",
) -> None:
    r"""Profile every config of an HPVM configuration file and save the result.

    The first config in the file is used as the baseline; all remaining configs
    are profiled against it and updated in place (see `profile_configs`).
    The baseline config itself is written back as it was read.

    :param binary_path: Path to binary to be executed in profiling.
    :param config_path: Path to config file (HPVM configuration format)
        with configs to enumerate for profiling.
    :param output_config_path: Path where the output configs are written.
        The output file holds the same configs as the input `config_path` file,
        with the profiled fields of the non-baseline configs refreshed.
    :param progress_bar: If `True`, show a progress bar for number of configs already profiled.
    :param profile_filename: Name of profile file generated by the binary (in current directory).
        This defaults to "profile_info.txt" and should not be changed for HPVM binaries.
    :param qos_filename: Name of QoS file generated by the binary (in current directory).
        It contains a single float number as the QoS of this run.
        This defaults to "final_accuracy" and should not be changed for HPVM binaries.
    :raises ValueError: If the config file contains no configs.
    """
    # `file_header` is the text before the first config; it is written back untouched.
    file_header, all_configs = read_hpvm_configs(Path(config_path))
    if len(all_configs) == 0:
        raise ValueError("Config file with no configs is unsupported.")
    baseline, *to_profile = all_configs
    # The Config objects in `to_profile` are shared with `all_configs`,
    # so the in-place updates below are visible when writing the file.
    profile_configs(
        binary_path=binary_path,
        configs=to_profile,
        baseline_config=baseline,
        progress_bar=progress_bar,
        profile_filename=profile_filename,
        qos_filename=qos_filename,
    )
    write_hpvm_configs(file_header, all_configs, Path(output_config_path))


def profile_configs(
    binary_path: PathLike,
    configs: Iterable["Config"],
    baseline_config: "Config",
    progress_bar: bool = True,
    profile_filename: str = "profile_info.txt",
    qos_filename: str = "final_accuracy",
) -> Iterable["Config"]:
    """Profile a sequence of HPVM configs against a baseline config.

    The baseline is measured once; then every config in `configs` is measured
    and its speedup (baseline time / config time), QoS and QoS loss are
    updated in place through `Config.update_profile_results`.

    :param binary_path: Path to binary to be executed in profiling.
    :param configs: Configs to profile. They are modified in place.
    :param baseline_config: Config whose time and QoS serve as the reference.
        It is measured but not modified.
    :param progress_bar: If `True`, show a `tqdm` progress bar over `configs`.
    :param profile_filename: Name of profile file generated by the binary (in current directory).
    :param qos_filename: Name of QoS file generated by the binary (in current directory).
    :returns: The same `configs` object that was passed in.
    """
    # The baseline run must read the same output files as every other run,
    # otherwise custom file names would only apply to the non-baseline configs.
    baseline_time, baseline_acc = measure_config(
        binary_path, baseline_config, profile_filename, qos_filename
    )
    if progress_bar:
        # Imported lazily so tqdm is only required when a progress bar is requested.
        from tqdm import tqdm

        iterable = tqdm(configs, desc="Configs profiled")
    else:
        iterable = configs
    for config in iterable:
        time, acc = measure_config(binary_path, config, profile_filename, qos_filename)
        speedup = baseline_time / time
        config.update_profile_results(speedup, acc, baseline_acc)
    return configs


def measure_config(
    binary_path: PathLike,
    config: "Config",
    profile_filename: str = "profile_info.txt",
    qos_filename: str = "final_accuracy",
) -> Tuple[float, float]:
    """Run the binary on a single config and read back its time and QoS.

    The config is written to a temporary config file which is handed to the
    binary via `-c`. The binary is expected to generate `profile_filename`
    and `qos_filename` in the current working directory.

    :param binary_path: Path to binary to be executed.
    :param config: The config to measure.
    :param profile_filename: Name of profile file generated by the binary (in current directory).
    :param qos_filename: Name of QoS file generated by the binary (in current directory).
    :returns: A `(time, qos)` tuple read from the two generated files.
    :raises subprocess.CalledProcessError: If the binary exits with a non-zero status.
        The captured standard output of the binary is printed before re-raising.
    """
    from subprocess import check_output
    from tempfile import NamedTemporaryFile

    # The context manager guarantees the temporary file is removed
    # even when the binary fails.
    with NamedTemporaryFile("w") as temp_file:
        write_hpvm_configs("0.0", [config], Path(temp_file.name))
        try:
            # Capture stdout (instead of discarding it) so that it can be
            # shown on failure; it is dropped when the run succeeds.
            check_output(
                [str(binary_path), "-c", str(temp_file.name)],
                universal_newlines=True,
            )
        except CalledProcessError as e:
            print("Output from the program:")
            print(e.output)
            raise
    time = _read_profile_file(Path(profile_filename))
    acc = _read_qos_file(Path(qos_filename))
    return time, acc


def plot_hpvm_configs(
    config_path: PathLike,
    save_to: Union[PathLike, None] = None,
    show_qos_loss: bool = True,
    **fig_kwargs,
) -> plt.Figure:
    """
    Plot the QoS-speedup information in an HPVM configuration file.
    It is recommended to profile the config file first (using `profile_config_file`)
    to obtain real speedup numbers.
    This function creates a `matplotlib.pyplot.Figure`, plots on it, and returns it.

    :param config_path: Path to the config file (HPVM configuration format).
    :param save_to: File to save figure into. Default is None: don't save figure (just return it).
    :param show_qos_loss: Show the loss of QoS on x axis of the figure. Defaults to True.
        If False, will use (absolute) QoS instead of QoS loss.
    :param fig_kwargs: Arguments to pass to `plt.subplots`.
    :returns: The figure holding the scatter plot.
    """
    _, configs = read_hpvm_configs(config_path)
    # Build the two coordinate lists separately so that a file
    # without any configs yields an empty plot instead of an unpacking error.
    if show_qos_loss:
        qoses = [c.qos_loss for c in configs]
        x_label = "QoS Loss"
    else:
        qoses = [c.qos for c in configs]
        x_label = "QoS"
    speedups = [c.speedup for c in configs]
    fig, ax = plt.subplots(**fig_kwargs)
    ax.scatter(qoses, speedups)
    # Axes objects use the `set_*` methods; `xlabel`/`ylabel` exist only on pyplot.
    ax.set_xlabel(x_label)
    ax.set_ylabel("Speedup (X)")
    if save_to:
        fig.savefig(save_to, dpi=300)
    return fig


class Config:
    """One config entry of an HPVM configuration file.

    The entry consists of a header line (name, speedup, energy, QoS, QoS loss)
    followed by body lines that are kept verbatim.
    """

    def __init__(
        self,
        conf_name: str,
        speedup: float,
        energy: float,
        qos: float,
        qos_loss: float,
        config_body: List[str],
    ):
        # Header fields, in the order they appear on the header line.
        self.conf_name = conf_name
        self.speedup = speedup
        self.energy = energy
        self.qos = qos
        self.qos_loss = qos_loss
        # The body is opaque to this module: stored and written back unparsed.
        self.config_body = config_body

    def update_profile_results(self, speedup: float, qos: float, base_qos: float):
        """Overwrite speedup, QoS and QoS loss with freshly measured values.

        :param speedup: Measured speedup of this config.
        :param qos: Measured QoS of this config.
        :param base_qos: Measured QoS of the baseline config.
        :raises ValueError: If the baseline QoS implied by the currently stored
            values (`qos + qos_loss`) differs from `base_qos` by more than 0.025.
        """
        recorded_base_qos = self.qos + self.qos_loss
        baseline_matches = abs(recorded_base_qos - base_qos) <= 0.025
        if not baseline_matches:
            raise ValueError(
                f"Baseline QoS mismatch. Original: {recorded_base_qos}, measured: {base_qos}"
            )
        self.speedup, self.qos = speedup, qos
        self.qos_loss = base_qos - qos

    def __repr__(self) -> str:
        """Serialize this config into the HPVM configuration text format."""
        header = " ".join(
            map(
                str,
                (self.conf_name, self.speedup, self.energy, self.qos, self.qos_loss),
            )
        )
        return "\n".join([conf_opening, header, *self.config_body, conf_closing])

    __str__ = __repr__


def read_hpvm_configs(config_file: PathLike) -> Tuple[str, List[Config]]:
    """Parse an HPVM configuration file.

    :param config_file: Path to the config file (HPVM configuration format).
    :returns: A `(header, configs)` tuple. `header` is the stripped text that
        precedes the first config (there's 1 float sitting on the first line of
        the config file; it is not used here but is kept so it can be written back).
        `configs` holds one `Config` per entry, in file order.
    :raises ValueError: If a config entry is empty or its header line does not
        consist of a name followed by 4 numbers.
    """
    with open(config_file) as f:
        text = f.read()
    header, *config_texts = text.split(conf_opening)
    ret_configs = []
    for config_text in config_texts:
        config_lines = config_text.replace(conf_closing, "").strip().splitlines()
        if not config_lines:
            raise ValueError(f"Empty config found in {config_file}")
        config_header, *config_body = config_lines
        # Split on any run of whitespace so that repeated spaces or tabs
        # between the header fields do not produce empty fields.
        conf_name, *number_fields = config_header.split()
        if len(number_fields) != 4:
            raise ValueError(
                f"Malformed config header in {config_file}: {config_header!r} "
                "(expected a name followed by 4 numbers)"
            )
        speedup, energy, qos, qos_drop = (float(s) for s in number_fields)
        ret_configs.append(
            Config(conf_name, speedup, energy, qos, qos_drop, config_body)
        )
    return header.strip(), ret_configs


def write_hpvm_configs(header: str, configs: Iterable[Config], to_file: PathLike):
    """Write a header and a sequence of configs to a file.

    The header comes first, followed by the string form of every config,
    all joined by newlines (no trailing newline is added).

    :param header: Text to put before the first config.
    :param configs: Configs to serialize, in order.
    :param to_file: Path of the file to (over)write.
    """
    segments = [header]
    segments.extend(str(cfg) for cfg in configs)
    with open(to_file, "w") as out:
        out.write("\n".join(segments))
        out.flush()


def _read_profile_file(profile_file_path: Path):
    """Extract the total time from a profile file.

    The file must contain exactly one line with the text "Total Time";
    the 4th whitespace-separated token of that line is returned as a float.

    :param profile_file_path: Path to the profile file.
    :raises RuntimeError: If the number of "Total Time" lines is not exactly 1.
    """
    matches = []
    with profile_file_path.open() as profile:
        for raw_line in profile:
            if "Total Time" in raw_line:
                matches.append(raw_line.strip())
    if len(matches) != 1:
        raise RuntimeError(f"Profile {profile_file_path} malformed")
    fields = matches[0].split()
    return float(fields[3])


def _read_qos_file(qos_file_path: Path):
    """Read a QoS file, whose whole content is a single float, and return that float.

    :param qos_file_path: Path to the QoS file.
    """
    with qos_file_path.open() as qos_file:
        content = qos_file.read()
    return float(content.strip())