From 5e671e05b2a6ca1a72a5de80b72af327dd0b0545 Mon Sep 17 00:00:00 2001
From: Yifan Zhao <yifanz16@illinois.edu>
Date: Tue, 23 Mar 2021 15:56:49 -0500
Subject: [PATCH] Grab data from Illinois data bank instead of git lfs

---
 .gitattributes                      |  1 -
 hpvm/scripts/hpvm_installer.py      | 43 +++++++++++++++++++++++++----
 hpvm/test/dnn_benchmarks/.gitignore |  1 +
 3 files changed, 39 insertions(+), 6 deletions(-)
 delete mode 100644 .gitattributes
 create mode 100644 hpvm/test/dnn_benchmarks/.gitignore

diff --git a/.gitattributes b/.gitattributes
deleted file mode 100644
index ff4f28aed4..0000000000
--- a/.gitattributes
+++ /dev/null
@@ -1 +0,0 @@
-hpvm/test/dnn_benchmarks/model_params/**/*.bin filter=lfs diff=lfs merge=lfs -text
diff --git a/hpvm/scripts/hpvm_installer.py b/hpvm/scripts/hpvm_installer.py
index 11ad304552..4d1fefdc7a 100755
--- a/hpvm/scripts/hpvm_installer.py
+++ b/hpvm/scripts/hpvm_installer.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 from pathlib import Path
 from argparse import ArgumentParser, Namespace
-from subprocess import check_call
+from subprocess import CalledProcessError, check_call
 from os import makedirs, chdir, environ
 
 VERSION = "9.0.0"
@@ -12,11 +12,15 @@ CLANG_TARBALL = f"{CLANG_DIR}.tar.xz"
 LLVM_DIR = f"llvm-{VERSION}.src"
 LLVM_TARBALL = f"{LLVM_DIR}.tar.xz"
 
-ROOT_DIR = Path.cwd()
+ROOT_DIR = (Path(__file__).parent / "..").absolute()
 BUILD_DIR = ROOT_DIR / "build"
 TEST_DIR = ROOT_DIR / "test"
 LLVM_LIT = BUILD_DIR / "bin/llvm-lit"
 
+MODEL_PARAMS_DIR = TEST_DIR / "dnn_benchmarks/model_params"
+MODEL_PARAMS_TAR = ROOT_DIR / "model_params.tar.gz"
+MODEL_PARAMS_LINK = "https://databank.illinois.edu/datafiles/o3izd/download"
+
 LINKS = [
     "CMakeLists.txt",
     "cmake",
@@ -63,6 +67,9 @@ def parse_args():
     parser.add_argument(
         "-r", "--run-tests", action="store_true", help="Build and run test cases"
     )
+    parser.add_argument(
+        "--no-params", action="store_true", help="Don't download DNN model parameters"
+    )
     return parser.parse_args()
 
 
@@ -74,11 +81,11 @@ def prompt_args():
     def parse_int(s: str):
         try:
             v = int(s)
-            return v
         except ValueError:
             return None
         if v <= 0:
             return None
+        return v
 
     def parse_targets(s: str):
         if " " in s:
@@ -124,7 +131,7 @@ def print_args(args):
 
 def check_download_llvm_clang():
     if Path("llvm/").is_dir():
-        print("Found LLVM, not extracting it again.")
+        print("Found LLVM directory, not extracting it again.")
     else:
         if Path(LLVM_TARBALL).is_file():
             print(f"Found {LLVM_TARBALL}, not downloading it again.")
@@ -141,7 +148,7 @@ def check_download_llvm_clang():
     environ["LLVM_SRC_ROOT"] = str(ROOT_DIR / "llvm")
 
     if (tools / "clang/").is_dir():
-        print("Found clang, not downloading it again.")
+        print("Found clang directory, not extracting it again.")
         return
     chdir(tools)
     print(f"Downloading {CLANG_TARBALL}...")
@@ -155,6 +162,54 @@ def check_download_llvm_clang():
     chdir(ROOT_DIR)
 
 
+def check_download_model_params():
+    """Download and unpack the DNN model parameters unless already present.
+
+    Does nothing when MODEL_PARAMS_DIR exists. Otherwise the tarball is fetched
+    from MODEL_PARAMS_LINK with wget (skipped when MODEL_PARAMS_TAR is already on
+    disk), unpacked into the current working directory, and the resulting
+    `model_params` directory is moved to MODEL_PARAMS_DIR. The tarball is deleted
+    once everything has succeeded.
+
+    Raises:
+        CalledProcessError: if wget, tar or mv exits with a non-zero status.
+    """
+    from shlex import quote
+    from shutil import which
+
+    if MODEL_PARAMS_DIR.is_dir():
+        print("Found model parameters, not extracting it again.")
+        return
+    if MODEL_PARAMS_TAR.is_file():
+        print(f"Found {MODEL_PARAMS_TAR}, not downloading it again.")
+    else:
+        print(f"Downloading DNN model parameters: {MODEL_PARAMS_TAR}...")
+        print("=============================")
+        try:
+            check_call([WGET, MODEL_PARAMS_LINK, "-O", MODEL_PARAMS_TAR])
+        except CalledProcessError:
+            # With `-O` the output file is created before any data arrives, so a
+            # failed download leaves a truncated tarball. Remove it, otherwise
+            # the next run would find it and skip the download.
+            if MODEL_PARAMS_TAR.is_file():
+                MODEL_PARAMS_TAR.unlink()
+            raise
+    print(f"Extracting DNN model parameters {MODEL_PARAMS_TAR} => {MODEL_PARAMS_DIR}...")
+    # Decompression is pretty time-consuming so we try to show a progress bar.
+    # Look for `pv` up front instead of catching a failed pipeline: a corrupt
+    # archive fails the pipeline too and must not be reported as a missing `pv`.
+    if which("pv"):
+        # Quote the path; it is interpolated into a shell command line.
+        check_call(f"pv {quote(str(MODEL_PARAMS_TAR))} | tar xz", shell=True)
+    else:
+        print(">> 'pv' is not installed, no progress bar will be shown during decompression.")
+        print(">> Decompression ongoing...")
+        check_call(["tar", "xzf", MODEL_PARAMS_TAR])
+    check_call(["mv", "model_params", MODEL_PARAMS_DIR])
+    if MODEL_PARAMS_TAR.is_file():
+        MODEL_PARAMS_TAR.unlink()
+
+
 def link_and_patch():
     from os import symlink
 
@@ -236,6 +267,8 @@ def main():
     print_args(args)
     check_download_llvm_clang()
     link_and_patch()
+    if not args.no_params:
+        check_download_model_params()
     maybe_build(not args.no_build, args.parallel, args.targets, args.run_tests)
     if args.run_tests:
         run_tests()
diff --git a/hpvm/test/dnn_benchmarks/.gitignore b/hpvm/test/dnn_benchmarks/.gitignore
new file mode 100644
index 0000000000..6363621cf1
--- /dev/null
+++ b/hpvm/test/dnn_benchmarks/.gitignore
@@ -0,0 +1 @@
+model_params/
-- 
GitLab