#!/usr/bin/env python3
import argparse
import os
from pathlib import Path
from typing import List, Union, Optional

# Anything this script accepts where a filesystem path is expected.
PathLike = Union[Path, str]

# NOTE(review): the "@NAME@" strings in this section look like build-time
# template placeholders (presumably substituted by CMake -- confirm).
# HPVM_PROJECT_DIR is not referenced elsewhere in the visible part of this file;
# LLVM_BUILD_BIN is where clang++, opt and llvm-link are looked up.
HPVM_PROJECT_DIR = Path("@LLVM_PROJECT_DIR@") / "tools/hpvm"
LLVM_BUILD_BIN = Path("@LLVM_BUILD_DIR@") / "bin"

# Directories to include
# All lists below are ";"-separated; note that an empty value splits to [""],
# not to an empty list.
# TRT_INCLUDE_DIRS: passed to clang++ as -I flags by hpvm_c_to_ll().
TRT_INCLUDE_DIRS = "@TRT_INCLUDE_DIRS@".split(";")
# TRT_LINK_DIRS: passed as -L / -Wl,-rpath= flags by link_binary().
TRT_LINK_DIRS = [Path(s) for s in "@TRT_LINK_DIRS@".split(";")]
# TRT_LINK_LIBS: library names passed as -l flags by link_binary().
TRT_LINK_LIBS = "@TRT_LINK_LIBS@".split(";")
# DIRECT_LINK_LIBS: library file paths; link_binary() splits each into its
# parent directory (for -L / rpath) and its file name (for -l).
DIRECT_LINK_LIBS = "@DIRECT_LINK_LIBS@".split(";")

# Pass plugin names that _run_opt() is allowed to load.
AVAILABLE_PASSES = "@AVAILABLE_PASSES@".split(";")
# File handed to llvm-link together with the generated IR (see link_hpvm_rt()).
HPVM_RT_PATH = "@HPVM_RT_PATH@"

# Flags always given to clang++ by hpvm_c_to_ll(); stored without the leading "-".
COMPILE_FLAGS = ["fno-exceptions", "std=c++11", "O3"]


def compile_hpvm_c(
    input_file: PathLike,
    output_file: PathLike,
    codegen_target: str = "tensor",
    include: Optional[List[PathLike]] = None,
    working_dir: Optional[PathLike] = None,
    conf_file: Optional[PathLike] = None,
):
    """Compile an HPVM-C source file into a binary.

    Runs five commands in order, printing each before executing it:
    clang++ (source -> .ll), opt GenHPVM (-> .hpvm.ll), opt backend codegen
    (-> .llvm.ll), llvm-link with the HPVM runtime (-> .linked.bc) and a final
    clang++ link step (-> output_file). Intermediate files are named after the
    stem of `input_file` and written to `working_dir`.

    Args:
        input_file: HPVM-C source to compile.
        output_file: Path of the binary to produce.
        codegen_target: "tensor" or "cudnn"; selects the backend codegen step.
        include: Extra include directories for the first clang++ step.
        working_dir: Directory for intermediate files; defaults to ".". It is
            created (with parents) when missing.
        conf_file: Configuration file forwarded to the "tensor" backend;
            mandatory for that target, ignored for "cudnn".

    Raises:
        ValueError: `codegen_target` is "tensor" but `conf_file` is None.
        KeyError: `codegen_target` is not a known backend.
        subprocess.CalledProcessError: one of the commands exits non-zero.
    """
    from subprocess import check_output

    # Fail before touching the filesystem: otherwise the literal string "None"
    # would be formatted into the opt command line as the configuration file.
    if codegen_target == "tensor" and conf_file is None:
        raise ValueError('Codegen target "tensor" requires a conf_file')

    codegen_functions = {
        "tensor": lambda i, o: opt_codegen_tensor(i, o, conf_file),
        "cudnn": opt_codegen_cudnn
    }
    codegen_f = codegen_functions[codegen_target]
    working_dir = Path(working_dir or ".")
    # exist_ok avoids the check-then-create race of is_dir() + makedirs().
    working_dir.mkdir(parents=True, exist_ok=True)
    name_stem = Path(input_file).stem

    ll_file = working_dir / f"{name_stem}.ll"
    hpvm_ll_file = working_dir / f"{name_stem}.hpvm.ll"
    llvm_ll_file = working_dir / f"{name_stem}.llvm.ll"
    hpvm_rt_linked_file = working_dir / f"{name_stem}.linked.bc"
    # Every command line is built up front, so an invalid configuration (e.g. an
    # unavailable pass) is reported before any tool is run.
    commands = [
        hpvm_c_to_ll(input_file, ll_file, extra_includes=include),
        opt_codegen_hpvm(ll_file, hpvm_ll_file),
        codegen_f(hpvm_ll_file, llvm_ll_file),
        link_hpvm_rt(llvm_ll_file, hpvm_rt_linked_file),
        link_binary(hpvm_rt_linked_file, output_file),
    ]
    for command in commands:
        print(" ".join(command))
        check_output(command)


def hpvm_c_to_ll(
    src_file: PathLike,
    target_file: PathLike,
    extra_includes: Optional[List[PathLike]] = None,
    flags: Optional[List[str]] = None,
) -> List[str]:
    """Build the clang++ command that emits textual LLVM IR for `src_file`.

    Args:
        src_file: Source file to compile.
        target_file: Where clang++ writes the IR (`-emit-llvm -S -o`).
        extra_includes: Include directories added after TRT_INCLUDE_DIRS.
        flags: Extra compiler flags, without the leading "-"; they are placed
            before COMPILE_FLAGS.

    Returns:
        The command as an argument list; nothing is executed here.
    """
    include_dirs = [*TRT_INCLUDE_DIRS, *(extra_includes or [])]
    # An unset TRT_INCLUDE_DIRS splits to [""]; a bare "-I" would make clang++
    # take the following argument as the include directory, so skip empties.
    includes = [f"-I{path}" for path in include_dirs if str(path)]
    flags = [f"-{flg}" for flg in [*(flags or []), *COMPILE_FLAGS]]
    return [
        str(LLVM_BUILD_BIN / "clang++"), *includes, *flags, "-emit-llvm", "-S",
        str(src_file), "-o", str(target_file)
    ]


def opt_codegen_hpvm(src_file: PathLike, target_file: PathLike) -> List[str]:
    """Return the `opt` command that loads LLVMGenHPVM and runs genhpvm + globaldce."""
    plugins = ["LLVMGenHPVM"]
    opt_flags = ["genhpvm", "globaldce"]
    return _run_opt(src_file, target_file, plugins, opt_flags)


def opt_codegen_cudnn(src_file: PathLike, target_file: PathLike) -> List[str]:
    """Return the `opt` command for the "cudnn" backend codegen step.

    The command loads the pass plugins listed below and enables the listed
    flags, in this order; nothing is executed here.
    """
    plugins = [
        "LLVMBuildDFG",
        "LLVMInPlaceDFGAnalysis",
        "LLVMDFG2LLVM_CUDNN",
        "LLVMDFG2LLVM_CPU",
        "LLVMFuseHPVMTensorNodes",
        "LLVMClearDFG",
        "LLVMGenHPVM",
    ]
    opt_flags = [
        "buildDFG",
        "inplace",
        "hpvm-fuse",
        "dfg2llvm-cudnn",
        "dfg2llvm-cpu",
        "clearDFG",
    ]
    return _run_opt(src_file, target_file, pass_names=plugins, pass_flags=opt_flags)


def opt_codegen_tensor(
    src_file: PathLike, target_file: PathLike, conf_file: PathLike
) -> List[str]:
    """Return the `opt` command for the "tensor" (wrapper API) backend codegen step.

    Args:
        src_file: Input IR file.
        target_file: Output IR file.
        conf_file: File passed to opt via `-configuration-inputs-filename=`.

    Returns:
        The command as an argument list; nothing is executed here.

    Raises:
        ValueError: `conf_file` is None, or a required pass is unavailable.
    """
    # Without this check None would be formatted into the flag as the literal
    # file name "None".
    if conf_file is None:
        raise ValueError('Codegen target "tensor" requires a conf_file')
    passes = [
        "LLVMBuildDFG", "LLVMInPlaceDFGAnalysis",
        "LLVMDFG2LLVM_WrapperAPI", "LLVMDFG2LLVM_CPU",
        "LLVMFuseHPVMTensorNodes", "LLVMClearDFG", "LLVMGenHPVM"
    ]
    flags = [
        "buildDFG", "inplace", "hpvm-fuse",
        "dfg2llvm-wrapperapi",
        f"configuration-inputs-filename={conf_file}",
        "dfg2llvm-cpu", "clearDFG",
    ]
    return _run_opt(src_file, target_file, passes, flags)


def link_hpvm_rt(src_file: PathLike, target_file: PathLike) -> List[str]:
    """Return the llvm-link command that combines `src_file` with HPVM_RT_PATH."""
    llvm_link = LLVM_BUILD_BIN / "llvm-link"
    command = [str(llvm_link), str(src_file), HPVM_RT_PATH]
    command += ["-o", str(target_file)]
    return command


def link_binary(src_file: PathLike, target_file: PathLike) -> List[str]:
    """Build the clang++ command that links `src_file` into the final binary.

    Each DIRECT_LINK_LIBS entry contributes its parent directory as a search
    directory and its file name as a library; TRT_LINK_DIRS and TRT_LINK_LIBS
    are appended after them. Every search directory is passed both as `-L` and
    as `-Wl,-rpath=`, and every library as `-l`, with a "lib....so" file name
    reduced to the part between "lib" and ".so".

    Args:
        src_file: File to link.
        target_file: Path of the binary to produce.

    Returns:
        The command as an argument list; nothing is executed here.
    """
    import re

    def drop_suffix(libname: str) -> str:
        # "libfoo.so" -> "foo"; names not of that shape are returned unchanged.
        match = re.match(r"lib(.*)\.so", libname)
        return libname if match is None else match.group(1)

    link_dirs, link_libnames = [], []
    for lib in DIRECT_LINK_LIBS:
        # An unset build variable splits to [""]; that would otherwise yield a
        # bare "-l" plus spurious "-L." / rpath entries.
        if not lib:
            continue
        lib = Path(lib)
        link_dirs.append(lib.parent)
        link_libnames.append(lib.name)
    link_dirs += TRT_LINK_DIRS
    link_libnames += [lib for lib in TRT_LINK_LIBS if lib]

    linker_dir_flags = []
    for path in link_dirs:
        linker_dir_flags.extend([f"-L{path}", f"-Wl,-rpath={path}"])
    # The "lib"/".so" decoration is stripped exactly once, here, for all names.
    linker_lib_flags = [f"-l{drop_suffix(lib)}" for lib in link_libnames]
    return [
        str(LLVM_BUILD_BIN / "clang++"), str(src_file),
        "-o", str(target_file), *linker_dir_flags, *linker_lib_flags
    ]


def _run_opt(
    src_file: PathLike,
    target_file: PathLike,
    pass_names: List[str],
    pass_flags: List[str],
) -> List[str]:
    """Assemble an `opt` command line that loads pass plugins and enables flags.

    Args:
        src_file: Input IR file.
        target_file: Output file (`-S ... -o`).
        pass_names: Plugin names; each becomes `-load <name>.so`.
        pass_flags: Flag names without the leading "-".

    Returns:
        The command as an argument list; nothing is executed here.

    Raises:
        ValueError: some of `pass_names` are not in AVAILABLE_PASSES.
    """
    missing = set(pass_names) - set(AVAILABLE_PASSES)
    if missing:
        raise ValueError(f"Passes {missing} are unavailable from CMake")

    command = [str(LLVM_BUILD_BIN / "opt")]
    for plugin in pass_names:
        command += ["-load", f"{plugin}.so"]
    command.extend(f"-{flag}" for flag in pass_flags)
    command += ["-S", str(src_file), "-o", str(target_file)]
    return command


def parse_args(argv: Optional[List[str]] = None):
    """Parse the hpvm-clang command line.

    Args:
        argv: Argument strings to parse (without the program name). When None,
            argparse falls back to the process's command-line arguments, which
            is the original behavior.

    Returns:
        An argparse.Namespace with attributes input_file, output_file,
        codegen_target, working_dir, conf_file and include -- the same names
        as the keyword parameters of compile_hpvm_c().

    Exits via parser.error() when the "tensor" target is selected without
    --conf-file.
    """
    parser = argparse.ArgumentParser("hpvm-clang")
    parser.add_argument("input_file", type=Path, help="HPVM-C code to compile")
    parser.add_argument("output_file", type=Path, help="Path to generate binary to")
    parser.add_argument(
        "-t",
        "--codegen-target",
        type=str,
        required=True,
        choices=["tensor", "cudnn"],
        help="Backend to use",
    )
    parser.add_argument(
        "-d", "--working-dir", type=Path, help="Directory to generate temp files in"
    )
    parser.add_argument(
        "--conf-file", type=Path,
        help="File to approximation configurations; required for 'tensor' target"
    )
    parser.add_argument(
        "-I", "--include", type=Path, action="append",
        help="Additional include directories to use"
    )

    args = parser.parse_args(argv)
    # argparse cannot express "required only for one choice", so check by hand.
    if args.codegen_target == "tensor" and args.conf_file is None:
        parser.error('Codegen target "tensor" requires --conf-file argument')
    return args


def main():
    """Script entry point: parse the command line and run the compilation."""
    cli_options = vars(parse_args())
    compile_hpvm_c(**cli_options)


if __name__ == "__main__":
    main()