Skip to content
Snippets Groups Projects
Commit fc569c07 authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Updated hpvm_installer for a few things:

Added nvdla/sw as a submodule;
Relaxed python version req to be >= 3.6;
Only install torch2hpvm
Don't download DNN params
parent f275ab1a
No related branches found
No related tags found
No related merge requests found
...@@ -2,3 +2,6 @@ ...@@ -2,3 +2,6 @@
path = hpvm/projects/predtuner path = hpvm/projects/predtuner
url = ../predtuner.git url = ../predtuner.git
branch = hpvm branch = hpvm
[submodule "hpvm/projects/sw"]
path = hpvm/projects/sw
url = https://github.com/nvdla/sw.git
Subproject commit 79538ba1b52b040a4a4645f630e457fa01839e90
...@@ -18,9 +18,6 @@ MODEL_PARAMS_TAR = Path("model_params.tar.gz") ...@@ -18,9 +18,6 @@ MODEL_PARAMS_TAR = Path("model_params.tar.gz")
MODEL_PARAMS_DIR = ROOT_DIR / "test/dnn_benchmarks/model_params" MODEL_PARAMS_DIR = ROOT_DIR / "test/dnn_benchmarks/model_params"
MODEL_PARAMS_LINK = "https://databank.illinois.edu/datafiles/o3izd/download" MODEL_PARAMS_LINK = "https://databank.illinois.edu/datafiles/o3izd/download"
NVDLA_URL = "https://github.com/nvdla/sw.git"
NVDLA_DIR = ROOT_DIR / "projects/sw"
LINKS = [ LINKS = [
"CMakeLists.txt", "CMakeLists.txt",
"cmake", "cmake",
...@@ -34,14 +31,9 @@ MAKE_TARGETS = ["hpvm-clang"] ...@@ -34,14 +31,9 @@ MAKE_TARGETS = ["hpvm-clang"]
MAKE_TEST_TARGETS = ["check-hpvm-dnn", "check-hpvm-pass"] MAKE_TEST_TARGETS = ["check-hpvm-dnn", "check-hpvm-pass"]
# Relative to project root which is __file__.parent.parent # Relative to project root which is __file__.parent.parent
PY_PACKAGES = [ PY_PACKAGES = ["projects/torch2hpvm"]
"projects/hpvm-profiler",
"projects/predtuner",
"projects/torch2hpvm",
"projects/keras",
]
PYTHON_REQ = ((3, 6), (3, 7)) # This means >= 3.6, < 3.7 PYTHON_REQ = (3, 6) # This means >= 3.6
def parse_args(args=None): def parse_args(args=None):
...@@ -90,12 +82,6 @@ def parse_args(args=None): ...@@ -90,12 +82,6 @@ def parse_args(args=None):
parser.add_argument( parser.add_argument(
"-r", "--run-tests", action="store_true", help="Build and run test cases" "-r", "--run-tests", action="store_true", help="Build and run test cases"
) )
parser.add_argument(
"--no-pypkg", action="store_true", help="Don't build the HPVM Python Packages"
)
parser.add_argument(
"--no-params", action="store_true", help="Don't download DNN model parameters"
)
parser.add_argument( parser.add_argument(
"cmake_args", "cmake_args",
type=str, type=str,
...@@ -163,14 +149,6 @@ Arguments: """ ...@@ -163,14 +149,6 @@ Arguments: """
if args.cmake_args.strip() != "": if args.cmake_args.strip() != "":
args.cmake_args = [f"-{arg}" for arg in args.cmake_args.split(" ")] args.cmake_args = [f"-{arg}" for arg in args.cmake_args.split(" ")]
args.no_pypkg = not input_with_check(
"Install HPVM Python Packages (recommended)? [y/n]: ",
parse_yn,
"Please enter y or n",
)
args.no_params = not input_with_check(
"Download DNN weights (recommended)? [y/n]: ", parse_yn, "Please enter y or n"
)
args.run_tests = input_with_check( args.run_tests = input_with_check(
"Build and run tests? [y/n]: ", parse_yn, "Please enter y or n" "Build and run tests? [y/n]: ", parse_yn, "Please enter y or n"
) )
...@@ -185,7 +163,6 @@ def print_args(args): ...@@ -185,7 +163,6 @@ def print_args(args):
print(f" Build system: {build_sys}") print(f" Build system: {build_sys}")
print(f" Threads: {args.parallel}") print(f" Threads: {args.parallel}")
print(f" Targets: {args.targets}") print(f" Targets: {args.targets}")
print(f" Download DNN weights: {not args.no_params}")
print(f" Run tests: {args.run_tests}") print(f" Run tests: {args.run_tests}")
print(f" CMake arguments: {args.cmake_args}") print(f" CMake arguments: {args.cmake_args}")
...@@ -193,14 +170,12 @@ def print_args(args): ...@@ -193,14 +170,12 @@ def print_args(args):
def check_python_version(): def check_python_version():
from sys import version_info, version, executable from sys import version_info, version, executable
lowest, highest = PYTHON_REQ if version_info < PYTHON_REQ:
if not (lowest <= version_info < highest): lowest_str = ".".join([str(n) for n in PYTHON_REQ])
lowest_str = ".".join([str(n) for n in lowest])
highest_str = ".".join([str(n) for n in highest])
version_short_str = ".".join([str(n) for n in version_info]) version_short_str = ".".join([str(n) for n in version_info])
raise RuntimeError( raise RuntimeError(
f"You are using Python {version_short_str}, unsupported by HPVM. " f"You are using Python {version_short_str}, unsupported by HPVM. "
f"HPVM requires Python version '{lowest_str} <= version < {highest_str}'.\n" f"HPVM requires Python version 'version >= {lowest_str}'.\n"
f"(Current Python binary: {executable})\n" f"(Current Python binary: {executable})\n"
f"Detailed version info:\n{version}" f"Detailed version info:\n{version}"
) )
...@@ -268,15 +243,6 @@ def check_download_model_params(): ...@@ -268,15 +243,6 @@ def check_download_model_params():
MODEL_PARAMS_TAR.unlink() MODEL_PARAMS_TAR.unlink()
def check_download_nvdla_sw():
if NVDLA_DIR.is_dir():
print("Found NVDLA compiler, not downloading it again.")
return
print(f"Downloading the NVDLA compiler into {NVDLA_DIR}...")
print(f"=============================")
check_call(["git", "clone", NVDLA_URL, NVDLA_DIR])
def link_and_patch(): def link_and_patch():
from os import symlink from os import symlink
...@@ -370,16 +336,13 @@ def main(): ...@@ -370,16 +336,13 @@ def main():
# Don't parse args if no args given -- use prompt mode # Don't parse args if no args given -- use prompt mode
args = prompt_args() if len(argv) == 1 else parse_args() args = prompt_args() if len(argv) == 1 else parse_args()
if not args.no_pypkg: # Help users pull submodule in case they forgot
check_python_version() check_call(["git", "submodule", "update", "--init", "--recursive"])
check_python_version()
print_args(args) print_args(args)
check_download_nvdla_sw()
check_download_llvm_clang() check_download_llvm_clang()
link_and_patch() link_and_patch()
if not args.no_params: install_py_packages()
check_download_model_params()
if not args.no_pypkg:
install_py_packages()
if args.no_build: if args.no_build:
print( print(
""" """
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment