from pathlib import Path
from subprocess import CalledProcessError, PIPE
from typing import Iterable, List, Tuple, Union
from dataclasses import dataclass
from tqdm import trange

# Anything accepted where a filesystem path is expected.
PathLike = Union[Path, str]
# Markers framing each configuration entry in an HPVM config file:
# every entry starts with `conf_opening` and ends with `conf_closing`.
conf_opening = "+++++"
conf_closing = "-----"


def profile_configs(
    binary_path: PathLike,
    config_path: PathLike,
    output_config_path: PathLike,
    profile_filename: str = "profile_info.txt",
    qos_filename: str = "final_accuracy",
):
    """
    Profile an HPVM configuration file with an HPVM binary.
    The configuration file must have the baseline as the first configuration.

    binary_path: Union[Path, str]
        Path to binary to be executed in profiling.
    config_path: Union[Path, str]
        Path to config file (HPVM configuration format)
        with configs to enumerate for profiling.
    output_config_path: Union[Path, str]
        Path where the output configs are written.
        The output config file has the same configs as the input `config_path` file,
        but the speedup and QoS readings of every non-baseline config are updated
        from the measurements. The baseline (first) config is written back unchanged.
    profile_filename: str
        Name of profile file generated by the binary (in current directory).
        This defaults to "profile_info.txt" and should not be changed for HPVM binaries.
    qos_filename: str
        Name of QoS file generated by the binary (in current directory).
        It contains a single float number as the QoS of this run.
        This defaults to "final_accuracy" and should not be changed for HPVM binaries.

    Raises ValueError if the config file contains no configs, or (from
    `Config.update_profile_results`) if the measured baseline QoS does not match
    a config's recorded baseline. Raises CalledProcessError if the binary exits
    with a non-zero status; the binary's standard output is printed first.
    """

    from subprocess import run
    from tempfile import NamedTemporaryFile

    # Read first line ("the float") and configs in config file
    header, configs = read_hpvm_configs(Path(config_path))
    if not configs:
        raise ValueError("Config file with no configs is unsupported.")
    baseline_time, baseline_acc = None, None
    # `with` guarantees the temporary file is closed (and deleted) even when
    # the binary fails or a result file cannot be parsed.
    with NamedTemporaryFile("w") as temp_file:
        temp_path = Path(temp_file.name)
        for idx in trange(len(configs), desc="Configs profiled"):
            config = configs[idx]
            # Write this single config to the temp config file for the binary.
            write_hpvm_config(header, [config], temp_path)
            # Run binary_path binary,
            # which generates `profile_filename` and `qos_filename` file in cwd.
            # `run` drains the stdout pipe (unlike `check_call`, which would leave
            # it unread and could deadlock on large output), and on failure the
            # captured output is available as `e.output`.
            try:
                run([str(binary_path), "-c", str(temp_path)], stdout=PIPE, check=True)
            except CalledProcessError as e:
                print("Output from the program:")
                output = e.output
                print(output.decode(errors="replace") if output is not None else output)
                raise
            # Read these two files for time and QoS info.
            time = _read_profile_file(Path(profile_filename))
            acc = _read_qos_file(Path(qos_filename))
            if idx == 0:
                # The first config is the baseline: it is the reference for the
                # speedup and QoS loss of all others and is not itself updated.
                baseline_time, baseline_acc = time, acc
                continue
            assert baseline_time is not None and baseline_acc is not None
            speedup = baseline_time / time
            config.update_profile_results(speedup, acc, baseline_acc)
    write_hpvm_config(header, configs, Path(output_config_path))


def plot_hpvm_configs(
    config_path: PathLike,
    save_to: Union[PathLike, None] = None,
    show_qos_loss: bool = True,
    **fig_kwargs,
):
    """
    Plot the QoS-speedup information in an HPVM configuration file.
    It is recommended to profile the config file first (using `profile_configs`)
    to obtain real speedup numbers.
    This function creates a `matplotlib.pyplot.Figure`, plots on it, and returns it.

    config_path: Union[Path, str]
        Path to the config file (HPVM configuration format).
    save_to: Union[Path, str, None]
        File to save figure into. Default is None: don't save figure (just return it).
    show_qos_loss: bool
        Show the loss of QoS on x axis of the figure. Defaults to True.
        If False, will use (absolute) QoS instead of QoS loss.
        The x-axis label follows this choice.
    fig_kwargs:
        Arguments to pass to `plt.subplots`.
    """

    import numpy as np
    import matplotlib.pyplot as plt

    _, configs = read_hpvm_configs(config_path)
    points = [(c.qos_loss if show_qos_loss else c.qos, c.speedup) for c in configs]
    # reshape(-1, 2) keeps a (0, 2) shape for a file with no configs,
    # so the transpose below still unpacks into two (empty) arrays.
    qos_speedup = np.array(points, dtype=float).reshape(-1, 2)
    qoses, speedups = qos_speedup.T
    fig, ax = plt.subplots(**fig_kwargs)
    ax.scatter(qoses, speedups)
    # `Axes` uses set_xlabel/set_ylabel; `xlabel`/`ylabel` exist only on pyplot.
    ax.set_xlabel("QoS Loss" if show_qos_loss else "QoS")
    ax.set_ylabel("Speedup (X)")
    if save_to:
        fig.savefig(save_to, dpi=300)
    return fig

@dataclass
class Config:
    """
    One configuration entry of an HPVM configuration file.

    Only the entry's header line (name followed by speedup, energy, QoS and
    QoS loss) is parsed into fields; every following line is kept verbatim in
    `config_body`. `str()` / `repr()` render the entry back into config-file
    text framed by `conf_opening` and `conf_closing`.
    """

    conf_name: str
    speedup: float
    energy: float
    qos: float
    qos_loss: float
    # We don't care about the information in this part, and we don't parse this.
    config_body: List[str]

    def update_profile_results(self, speedup: float, qos: float, base_qos: float):
        """
        Replace speedup, QoS and QoS loss with freshly measured values.

        `base_qos` is the measured QoS of the baseline. It must agree (within an
        absolute tolerance of 1e-3) with the baseline implied by this entry's
        recorded numbers (`qos + qos_loss`). On mismatch, ValueError is raised
        and the entry is left untouched. `energy` is never modified.
        """
        recorded_base_qos = self.qos + self.qos_loss
        deviation = abs(recorded_base_qos - base_qos)
        if deviation > 1e-3:
            raise ValueError(
                f"Baseline QoS mismatch. Original: {recorded_base_qos}, measured: {base_qos}"
            )
        # New QoS loss is measured relative to the measured baseline.
        self.speedup, self.qos, self.qos_loss = speedup, qos, base_qos - qos

    def __repr__(self) -> str:
        numbered = (self.conf_name, self.speedup, self.energy, self.qos, self.qos_loss)
        header_line = " ".join(map(str, numbered))
        framed = [conf_opening, header_line]
        framed.extend(self.config_body)
        framed.append(conf_closing)
        return "\n".join(framed)

    __str__ = __repr__


def read_hpvm_configs(config_file: PathLike) -> Tuple[str, List[Config]]:
    """
    Parse an HPVM configuration file into its header and configurations.

    Everything before the first `conf_opening` marker is returned (stripped) as
    the header. Each following chunk is one configuration: `conf_closing`
    markers are removed, the first line must be
    "<name> <speedup> <energy> <qos> <qos_loss>", and the remaining lines are
    kept unparsed as the config body.

    Returns (header, configs).
    Raises ValueError if a configuration chunk is empty, or if its header line
    does not have exactly four numeric fields after the name.
    """
    with open(config_file) as f:
        text = f.read()
    # There's 1 float sitting on the first line of config file.
    # We don't use it, but want to keep that intact.
    header, *config_texts = text.split(conf_opening)
    ret_configs = []
    for config_text in config_texts:
        lines = config_text.replace(conf_closing, "").strip().splitlines()
        if not lines:
            raise ValueError(f"Empty configuration found in {config_file}")
        config_header, *config_body = lines
        # split() rather than split(" ") so repeated or trailing whitespace
        # in the header line does not produce empty (unparseable) fields.
        conf_name, *number_fields = config_header.split()
        speedup, energy, qos, qos_drop = [float(s) for s in number_fields]
        ret_configs.append(
            Config(conf_name, speedup, energy, qos, qos_drop, config_body)
        )
    return header.strip(), ret_configs


def write_hpvm_config(header: str, configs: Iterable[Config], to_file: PathLike):
    """
    Write a header and configurations to `to_file` in HPVM config format.

    The header comes first, followed by `str(config)` for each config, all
    joined by single newlines (no trailing newline is added). The full text
    is rendered before the file is opened, so a failure while rendering a
    config leaves any existing file untouched.
    """
    rendered = "\n".join([header, *map(str, configs)])
    with open(to_file, "w") as out:
        out.write(rendered)


def _read_profile_file(profile_file_path: Path):
    with profile_file_path.open() as f:
        target_lines = [line.strip() for line in f if "Total Time" in line]
    if len(target_lines) != 1:
        raise RuntimeError(f"Profile {profile_file_path} malformed")
    (target_line,) = target_lines
    return float(target_line.split()[3])


def _read_qos_file(qos_file_path: Path):
    with qos_file_path.open() as f:
        return float(f.read().strip())