#!/usr/bin/env python3
import argparse
import os
from pathlib import Path
from typing import List, Union, Optional

# Alias for arguments that accept either a pathlib.Path or a plain string path.
PathLike = Union[Path, str]

# NOTE: every "@NAME@" token in this section is a template placeholder. The values
# are only meaningful after the build system has substituted them (presumably
# CMake, given the TODO below -- confirm against the build scripts).
HPVM_PROJECT_DIR = Path("@LLVM_PROJECT_DIR@") / "tools/hpvm"
# Directory holding the clang / clang++ / opt / llvm-link binaries invoked by this script.
LLVM_BUILD_BIN = Path("@LLVM_BUILD_DIR@") / "bin"

# Directories to include
# (semicolon-separated; note that an empty substitution yields [""] rather than [])
INCLUDE_DIRS = "@INCLUDE_DIRS@".split(";")
# TODO: This dependency comes from hpvm-rt. Should have CMake inject this instead.
LINK_LIBS = ["pthread"]
# Semicolon-separated library file paths. _parse_direct_link_libs() splits each one
# into a directory (used for -L / rpath) and a file name (used for -l:<name>).
DIRECT_LINK_LIBS = "@DIRECT_LINK_LIBS@".split(";")
# Integer feature switches consulted by parse_args() to reject unsupported options.
HPVM_USE_OPENCL = int("@HPVM_USE_OPENCL@")
HPVM_HAS_TRT = int("@HPVM_HAS_TRT@")

# Names of the pass plugins available in this build; _run_opt() refuses any other pass.
AVAILABLE_PASSES = "@AVAILABLE_PASSES@".split(";")
# Extra input handed to llvm-link by link_hpvm_rt().
HPVM_RT_PATH = "@HPVM_RT_PATH@"


def _hpvm_lowering_passes(tensor_target, opencl, conf_file):
    """Select the opt pass plugins and flags for the DFG lowering stage.

    Returns a ``(passes, pass_flags)`` pair of string lists: ``passes`` holds
    the plugin names to load and ``pass_flags`` the flags (without the leading
    dash) to hand to ``opt``, in the order they must run.

    Raises:
        ValueError: if ``tensor_target`` is not one of "tensor", "nvdla",
            "cudnn" or None, or if it is "tensor"/"nvdla" and ``conf_file``
            is None.
    """
    passes = ["LLVMBuildDFG"]
    pass_flags = ["buildDFG"]
    if tensor_target == "tensor":
        if conf_file is None:
            raise ValueError("conf_file must be defined when tensor_target=='tensor'.")
        passes += ["LLVMInPlaceDFGAnalysis", "LLVMFuseHPVMTensorNodes", "LLVMDFG2LLVM_WrapperAPI"]
        pass_flags += [
            "inplace", "hpvm-fuse", "dfg2llvm-wrapperapi",
            f"configuration-inputs-filename={conf_file}"
        ]
    elif tensor_target == "nvdla":
        # NOTE(review): conf_file is demanded here but never forwarded to the
        # NVDLA pass; the check is kept for compatibility -- confirm intent.
        if conf_file is None:
            raise ValueError("conf_file must be defined when tensor_target=='nvdla'.")
        passes += ["LLVMHPVM2NVDLAPass"]
        # The precision and the calibration-table file name are fixed here.
        pass_flags += [
            "hpvm-nvdla",
            "cprecision=fp16",
            "calib-table=calib.txt"
        ]
    elif tensor_target == "cudnn":
        passes += ["LLVMInPlaceDFGAnalysis", "LLVMDFG2LLVM_CUDNN"]
        pass_flags += ["inplace", "dfg2llvm-cudnn"]
    elif tensor_target is None:
        passes += ["LLVMLocalMem"]
        pass_flags += ["localmem"]
    else:
        raise ValueError(f"Tensor target {tensor_target} not recognized")
    if opencl:
        passes += ["LLVMDFG2LLVM_OpenCL"]
        pass_flags += ["dfg2llvm-opencl"]
    # The CPU backend and DFG cleanup always run last.
    passes += ["LLVMDFG2LLVM_CPU", "LLVMClearDFG"]
    pass_flags += ["dfg2llvm-cpu", "clearDFG"]
    return passes, pass_flags


def compile_hpvm_c(
    hpvm_src: PathLike,
    output_file: PathLike,
    tensor_target: Optional[str],
    opencl: bool,
    link_bitcode: List[PathLike] = None,
    include: List[PathLike] = None,
    macro: List[str] = None,
    flags: List[str] = None,
    optim_level: str = "1",  # -O1
    is_cpp: bool = True,  # otherwise is C
    std: str = None,  # language std (-std=c++11)
    link_dirs: List[PathLike] = None,
    link_libs: List[str] = None,
    working_dir: PathLike = None,
    conf_file: PathLike = None,
    verbose: bool = False,
):
    """Compile a single HPVM-C source file into a binary.

    Runs, in order: clang (source -> LLVM IR), opt with the GenHPVM pass,
    opt with the lowering passes chosen from ``tensor_target``/``opencl``,
    llvm-link against the HPVM runtime, and a final clang++ link.

    Args:
        hpvm_src: HPVM-C source file; its stem names the intermediate files.
        output_file: path of the binary to produce.
        tensor_target: "tensor", "nvdla", "cudnn", or None for no tensor backend.
        opencl: also run the OpenCL lowering pass.
        link_bitcode: extra bitcode files linked in before the compiled module.
        include, macro, flags, std, is_cpp: forwarded to ``hpvm_c_to_ll``.
        optim_level: forwarded to ``hpvm_c_to_ll`` (which currently does not
            put it on the command line).
        link_dirs, link_libs: forwarded to ``link_binary``.
        working_dir: directory for intermediate files; created if missing,
            defaults to the current directory.
        conf_file: configuration file; required for "tensor" and "nvdla".
        verbose: print each command before running it.

    Raises:
        ValueError: for an unknown ``tensor_target``, a missing ``conf_file``,
            or a pass that is unavailable in this build.
        subprocess.CalledProcessError: if any tool exits with a non-zero status.
    """
    from subprocess import check_output

    flags = (flags or [])
    # Validate the target before touching the filesystem.
    passes, pass_flags = _hpvm_lowering_passes(tensor_target, opencl, conf_file)

    working_dir = Path(working_dir or ".")
    # exist_ok avoids the check-then-create race of is_dir() + makedirs().
    working_dir.mkdir(parents=True, exist_ok=True)

    # All commands for compiling the main hpvm_c file
    name_stem = Path(hpvm_src).stem
    ll_file = working_dir / f"{name_stem}.ll"
    hpvm_ll_file = working_dir / f"{name_stem}.hpvm.ll"
    llvm_ll_file = working_dir / f"{name_stem}.llvm.ll"
    hpvm_rt_linked_file = working_dir / f"{name_stem}.linked.bc"
    link_bitcode_ = [Path(bc) for bc in (link_bitcode or [])]
    # Every command line is built up front, so an unavailable pass is reported
    # before any tool has been executed.
    commands = [
        hpvm_c_to_ll(hpvm_src, ll_file, include, macro, flags, optim_level, is_cpp, std),
        opt_codegen_hpvm(ll_file, hpvm_ll_file),
        _run_opt(hpvm_ll_file, llvm_ll_file, passes, pass_flags),
        link_hpvm_rt(link_bitcode_ + [llvm_ll_file], hpvm_rt_linked_file),
        link_binary(hpvm_rt_linked_file, output_file, link_dirs, link_libs),
    ]
    for command in commands:
        if verbose:
            print(" ".join(command))
        check_output(command)


def hpvm_c_to_ll(
    src_file: PathLike,
    target_file: PathLike,
    extra_includes: List[PathLike] = None,
    macros: List[str] = None,
    flags: List[str] = None,
    optim_level: str = "1",  # -O1
    is_cpp: bool = True,  # otherwise is C
    std: str = None,  # --std=c++11
) -> List[str]:
    """Build the clang command that turns a source file into textual LLVM IR.

    The command uses clang++ when ``is_cpp`` is true and clang otherwise, and
    always compiles with ``-O0 -Xclang -disable-O0-optnone``; ``optim_level``
    is accepted but does not appear in the command. Caller-supplied include
    directories come before the built-in INCLUDE_DIRS, followed by the ``-f``
    flags (with ``-std=`` last among them) and then the ``-D`` macros.

    Returns the argument list; nothing is executed here.
    """
    compiler = LLVM_BUILD_BIN / ("clang++" if is_cpp else "clang")
    search_paths = (extra_includes or []) + INCLUDE_DIRS

    command = [str(compiler), "-O0", "-Xclang", "-disable-O0-optnone"]
    command += [f"-I{directory}" for directory in search_paths]
    command += [f"-f{name}" for name in (flags or [])]
    if std:
        command.append(f"-std={std}")
    command += [f"-D{definition}" for definition in (macros or [])]
    command += ["-emit-llvm", "-S", str(src_file), "-o", str(target_file)]
    return command


def opt_codegen_hpvm(src_file: PathLike, target_file: PathLike) -> List[str]:
    """Build the ``opt`` command that loads LLVMGenHPVM and runs genhpvm + globaldce."""
    plugins = ["LLVMGenHPVM"]
    opt_flags = ["genhpvm", "globaldce"]
    return _run_opt(src_file, target_file, plugins, opt_flags)


def link_hpvm_rt(bitcodes: List[PathLike], target_file: PathLike) -> List[str]:
    """Build the ``llvm-link`` command that merges ``bitcodes`` with HPVM_RT_PATH.

    The inputs are listed in the given order with HPVM_RT_PATH appended last;
    ``-S`` asks for textual output written to ``target_file``.
    """
    command = [str(LLVM_BUILD_BIN / "llvm-link")]
    command.extend(str(bitcode) for bitcode in bitcodes)
    command += [HPVM_RT_PATH, "-S", "-o", str(target_file)]
    return command


def link_binary(
    src_file: PathLike,
    target_file: PathLike,
    extra_link_dirs: List[PathLike] = None,
    extra_link_libs: List[str] = None
) -> List[str]:
    """Build the clang++ command that links ``src_file`` into ``target_file``.

    Library search directories are the parents of DIRECT_LINK_LIBS followed by
    ``extra_link_dirs``; each one is passed both as ``-L`` and as an rpath.
    Libraries are ``extra_link_libs`` then LINK_LIBS (as ``-l<name>``),
    followed by the DIRECT_LINK_LIBS file names (as ``-l:<file>``).
    """
    search_dirs, direct_lib_files = _parse_direct_link_libs()
    search_dirs.extend(extra_link_dirs or [])
    lib_stems = (extra_link_libs or []) + LINK_LIBS

    command = [str(LLVM_BUILD_BIN / "clang++"), str(src_file), "-o", str(target_file)]
    for directory in search_dirs:
        command += [f"-L{directory}", f"-Wl,-rpath={directory}"]
    command += [f"-l{stem}" for stem in lib_stems]
    command += [f"-l:{file_name}" for file_name in direct_lib_files]
    return command


def _parse_direct_link_libs():
    """Split DIRECT_LINK_LIBS into parallel lists of directories and file names.

    Returns ``(link_dirs, link_libnames)``: for every library path, its parent
    directory (a Path) and its file name (a str). Both lists are freshly
    created on each call, so callers may mutate them.

    Empty entries are skipped. ``"".split(";")`` produces ``[""]``, so an
    empty DIRECT_LINK_LIBS setting would otherwise contribute a "." search
    directory and a library with an empty name.
    """
    link_dirs, link_libnames = [], []
    for entry in DIRECT_LINK_LIBS:
        if not entry:
            continue
        lib = Path(entry)
        link_dirs.append(lib.parent)
        link_libnames.append(lib.name)
    return link_dirs, link_libnames


def _run_opt(
    src_file: PathLike,
    target_file: PathLike,
    pass_names: List[str],
    pass_flags: List[str],
) -> List[str]:
    """Build an ``opt`` command that loads ``pass_names`` and applies ``pass_flags``.

    Each pass name becomes ``-load <name>.so`` and each flag gets a leading
    dash; ``-mem2reg`` is always placed first and the output is textual (``-S``).

    Raises:
        ValueError: if any requested pass is not listed in AVAILABLE_PASSES.
    """
    unavailable = set(pass_names) - set(AVAILABLE_PASSES)
    if unavailable:
        raise ValueError(f"Passes {unavailable} are unavailable for this compilation.")

    command = [str(LLVM_BUILD_BIN / "opt"), "-mem2reg"]
    for plugin in pass_names:
        command += ["-load", f"{plugin}.so"]
    command += [f"-{flag}" for flag in pass_flags]
    command += ["-S", str(src_file), "-o", str(target_file)]
    return command


def parse_args(argv: Optional[List[str]] = None):
    """Parse and validate the hpvm-clang command line.

    Args:
        argv: argument list to parse; when None, argparse falls back to
            ``sys.argv[1:]``.

    Returns:
        The argparse namespace, with an extra boolean ``is_cpp`` attribute
        derived from ``-x``.

    Exits through ``parser.error`` when a tensor target that needs
    ``--conf-file`` is given without one, when ``-x`` is neither "c" nor
    "c++", when ``--opencl`` is requested on a build without OpenCL, or when a
    tensor target other than "nvdla" is requested on a build without tensor
    domain support.
    """
    parser = argparse.ArgumentParser("hpvm-clang")
    parser.add_argument(
        "hpvm_src", type=Path,
        help="""HPVM-C code to compile.
HPVM-C code must be single file, but additional bitcode file can be linked together.
See option -b for that."""
    )
    parser.add_argument("output_file", type=Path, help="Path to generate binary to")
    parser.add_argument(
        "-x", type=str, metavar="language", default="c++",
        help="Treat input file as having type <language>",
    )
    parser.add_argument(
        "-b",
        "--link-bitcode",
        type=Path,
        nargs="+",
        help="Additional bitcode (.ll/.bc) files to link to",
    )
    parser.add_argument(
        "-t",
        "--tensor-target",
        type=str,
        choices=["tensor", "cudnn", "nvdla"],
        help="Backend to use for tensor operators",
    )
    parser.add_argument(
        "--conf-file", type=Path,
        help="File to approximation configurations; "
             "required for tensor targets 'tensor' and 'nvdla'"
    )
    parser.add_argument(
        "--opencl",
        action="store_true",
        help="Use OpenCL support. Requires HPVM built with OpenCL",
    )
    parser.add_argument(
        "-d", "--working-dir", type=Path, help="Directory to generate temp files in"
    )

    # Relaying arguments for clang++ (source -> bitcode stage)
    parser.add_argument(
        "-I", "--include", type=Path, action="append", metavar="dir",
        help="[clang emit-llvm] Add directory to include search path"
    )
    parser.add_argument(
        "-D", type=str, action="append", metavar="<macro>=<value>",
        help="[clang emit-llvm] Define macro"
    )
    parser.add_argument(
        "-f", type=str, action="append", metavar="flag",
        help="[clang emit-llvm] clang++ flags (such as -ffastmath)"
    )
    parser.add_argument(
        "-O", type=str, default="1", metavar="level",
        help="[clang emit-llvm] Optimization level. Note that default is -O1."
    )
    parser.add_argument(
        "--std", type=str,
        help="[clang emit-llvm] Language standard to compile for. Double dashes (--std, not -std)."
    )

    # Relaying arguments for clang++ (linking stage)
    parser.add_argument(
        "-L", type=Path, action="append", metavar="dir",
        help="[clang linker] Add directory to library search path"
    )
    parser.add_argument(
        "-l", type=str, action="append", metavar="name",
        help="[clang linker] Link library (such as -lpthread)"
    )

    parser.add_argument(
        "-v", "--verbose", action="store_true", help="Print out all clang/opt/llvm-link commands used"
    )

    args = parser.parse_args(argv)
    if args.conf_file is None:
        if args.tensor_target == "tensor":
            parser.error('Tensor target "tensor" requires --conf-file argument')
        if args.tensor_target == "nvdla":
            # compile_hpvm_c() rejects this combination too; report it here as
            # a usage error instead of a traceback.
            parser.error('Tensor target "nvdla" requires --conf-file argument')

    if args.x == "c":
        args.is_cpp = False
    elif args.x == "c++":
        args.is_cpp = True
    else:
        parser.error(f"Language mode {args.x} not supported yet -- only c or c++")
    if not HPVM_USE_OPENCL and args.opencl:
        parser.error("OpenCL is disabled for this build of HPVM.")
    # Only complain when a tensor backend was actually requested: a run with no
    # tensor target does not use tensor domain support, and "nvdla" is exempt.
    if not HPVM_HAS_TRT and args.tensor_target not in (None, "nvdla"):
        parser.error(
            "Tensor domain support is disabled for this build of HPVM; "
            "please check your CMake warnings during compilation."
        )
    return args


def main():
    """Parse the command line and run the compilation.

    The single-letter option names produced by argparse are renamed to the
    keyword names ``compile_hpvm_c`` expects, and ``x`` is dropped because
    ``parse_args`` has already translated it into ``is_cpp``.
    """
    options = vars(parse_args())
    renames = {
        "D": "macro",
        "f": "flags",
        "O": "optim_level",
        "L": "link_dirs",
        "l": "link_libs",
    }
    for cli_name, keyword in renames.items():
        options[keyword] = options.pop(cli_name)
    del options["x"]
    compile_hpvm_c(**options)


if __name__ == "__main__":
    main()