diff --git a/hpvm/tools/py-approxhpvm/main.py.in b/hpvm/tools/py-approxhpvm/main.py.in index 2f007a5317d53d45f4e89e852b57648748f7c9ab..5bdc398b410b19b5e0d8c90880edf7bd3d744e07 100644 --- a/hpvm/tools/py-approxhpvm/main.py.in +++ b/hpvm/tools/py-approxhpvm/main.py.in @@ -29,7 +29,8 @@ def compile_hpvm_c( macro: List[str] = None, flags: List[str] = None, optim_level: str = "0", # -O0 - std: str = "c++11", # --std=c++11 + is_cpp: bool = True, # otherwise is C + std: str = None, # language std (-std=c++11) link_dirs: List[PathLike] = None, link_libs: List[str] = None, working_dir: PathLike = None, @@ -37,20 +38,28 @@ def compile_hpvm_c( ): from subprocess import check_output - passes = ["LLVMBuildDFG", "LLVMInPlaceDFGAnalysis"] - pass_flags = ["buildDFG", "inplace"] + passes = ["LLVMBuildDFG"] + pass_flags = ["buildDFG"] if tensor_target == "tensor": if conf_file is None: raise ValueError("conf_file must be defined when tensor_target=='tensor'.") - passes += ["LLVMFuseHPVMTensorNodes", "LLVMDFG2LLVM_WrapperAPI"] - pass_flags += ["hpvm-fuse", "dfg2llvm-wrapperapi", f"configuration-inputs-filename={conf_file}"] + passes += ["LLVMInPlaceDFGAnalysis", "LLVMFuseHPVMTensorNodes", "LLVMDFG2LLVM_WrapperAPI"] + pass_flags += [ + "inplace", "hpvm-fuse", "dfg2llvm-wrapperapi", + f"configuration-inputs-filename={conf_file}" + ] elif tensor_target == "cudnn": - passes += ["LLVMDFG2LLVM_CUDNN"] - pass_flags += ["dfg2llvm-cudnn"] + passes += ["LLVMInPlaceDFGAnalysis", "LLVMDFG2LLVM_CUDNN"] + pass_flags += ["inplace", "dfg2llvm-cudnn"] + elif tensor_target is None: + passes += ["LLVMLocalMem"] + pass_flags += ["localmem"] + else: + raise ValueError(f"Tensor target {tensor_target} not recognized") if opencl: passes += ["LLVMDFG2LLVM_OpenCL"] pass_flags += ["dfg2llvm-opencl"] - passes += ["LLVMDFG2LLVM_CPU", "LLVMClearDFG", "LLVMGenHPVM"] + passes += ["LLVMDFG2LLVM_CPU", "LLVMClearDFG"] pass_flags += ["dfg2llvm-cpu", "clearDFG"] working_dir = Path(working_dir or ".") @@ -63,15 +72,15 @@ def compile_hpvm_c( hpvm_ll_file = working_dir / f"{name_stem}.hpvm.ll" llvm_ll_file = working_dir / f"{name_stem}.llvm.ll" hpvm_rt_linked_file = working_dir / f"{name_stem}.linked.bc" + link_bitcode_ = [Path(bc) for bc in (link_bitcode or [])] commands = [ - hpvm_c_to_ll(hpvm_src, ll_file, include, macro, flags, optim_level, std), + hpvm_c_to_ll(hpvm_src, ll_file, include, macro, flags, optim_level, is_cpp, std), opt_codegen_hpvm(ll_file, hpvm_ll_file), _run_opt(hpvm_ll_file, llvm_ll_file, passes, pass_flags), - link_hpvm_rt(llvm_ll_file, hpvm_rt_linked_file), + link_hpvm_rt(link_bitcode_ + [llvm_ll_file], hpvm_rt_linked_file), ] - link_bitcode_ = [Path(bc) for bc in (link_bitcode or [])] commands.append( - link_binary(link_bitcode_ + [hpvm_rt_linked_file], output_file, link_dirs, link_libs) + link_binary(hpvm_rt_linked_file, output_file, link_dirs, link_libs) ) for command in commands: print(" ".join(command)) @@ -85,15 +94,19 @@ def hpvm_c_to_ll( macros: List[str] = None, flags: List[str] = None, optim_level: str = "0", # -O0 - std: str = "c++11", # --std=c++11 + is_cpp: bool = True, # otherwise is C + std: str = None, # --std=c++11 ) -> List[str]: includes = (extra_includes or []) + TRT_INCLUDE_DIRS includes_s = [f"-I{path}" for path in includes] macros = [f"-D{macro}" for macro in (macros or [])] flags = [f"-f{flg}" for flg in (flags or [])] + if std: + flags.append(f"-std={std}") + clang = "clang++" if is_cpp else "clang" return [ - str(LLVM_BUILD_BIN / "clang++"), *includes_s, *flags, - f"-O{optim_level}", f"-std={std}", "-emit-llvm", "-S", + str(LLVM_BUILD_BIN / clang), *includes_s, *flags, *macros, + f"-O{optim_level}", "-emit-llvm", "-S", str(src_file), "-o", str(target_file) ] @@ -102,24 +115,24 @@ def opt_codegen_hpvm(src_file: PathLike, target_file: PathLike) -> List[str]: return _run_opt(src_file, target_file, ["LLVMGenHPVM"], ["genhpvm", "globaldce"]) -def link_hpvm_rt(src_file: PathLike, target_file: PathLike) -> List[str]: - return [str(LLVM_BUILD_BIN / "llvm-link"), str(src_file), HPVM_RT_PATH, "-o", str(target_file)] +def link_hpvm_rt(bitcodes: List[PathLike], target_file: PathLike) -> List[str]: + bitcodes_s = [str(bc) for bc in bitcodes] + return [str(LLVM_BUILD_BIN / "llvm-link"), *bitcodes_s, HPVM_RT_PATH, "-S", "-o", str(target_file)] def link_binary( - src_files: List[PathLike], + src_file: PathLike, target_file: PathLike, extra_link_dirs: List[PathLike] = None, extra_link_libs: List[str] = None ) -> List[str]: - src_files_s = [str(file) for file in src_files] link_dirs, link_libs = _link_args(extra_link_dirs or [], extra_link_libs or []) linker_dir_flags = [] for path in link_dirs: linker_dir_flags.extend([f"-L{path}", f"-Wl,-rpath={path}"]) linker_lib_flags = [f"-l{lib}" for lib in link_libs] return [ - str(LLVM_BUILD_BIN / "clang++"), *src_files_s, + str(LLVM_BUILD_BIN / "clang++"), str(src_file), "-o", str(target_file), *linker_dir_flags, *linker_lib_flags ] @@ -168,6 +181,10 @@ HPVM-C code must be single file, but additional bitcode file can be linked toget See option -b for that.""" ) parser.add_argument("output_file", type=Path, help="Path to generate binary to") + parser.add_argument( + "-x", type=str, metavar="language", default="c++", + help="Treat input file as having type <language>", + ) parser.add_argument( "-b", "--link-bitcode", @@ -213,7 +230,7 @@ See option -b for that.""" help="[clang emit-llvm] Optimization level" ) parser.add_argument( - "--std", type=str, default="c++11", + "--std", type=str, help="[clang emit-llvm] Language standard to compile for. Double dashes (--std, not -std)." ) @@ -231,6 +248,12 @@ See option -b for that.""" if args.tensor_target == "tensor": if args.conf_file is None: parser.error('Tensor target "tensor" requires --conf-file argument') + if args.x == "c": + args.is_cpp = False + elif args.x == "c++": + args.is_cpp = True + else: + parser.error(f"Language mode {args.x} not supported yet -- only c or c++") return args @@ -241,4 +264,5 @@ if __name__ == "__main__": args["optim_level"] = args.pop("O") args["link_dirs"] = args.pop("L") args["link_libs"] = args.pop("l") + args.pop("x") compile_hpvm_c(**args)