From 0a93985e52499a290b2ded54e8df6dff9f40eaf5 Mon Sep 17 00:00:00 2001
From: Yifan Zhao <yifanz16@illinois.edu>
Date: Thu, 3 Dec 2020 01:23:19 -0600
Subject: [PATCH] Parse argument from cli instead of config.py

---
 hpvm/projects/onnx/frontend/config.py |  7 --
 hpvm/projects/onnx/frontend/main.py   | 98 ++++++++++++++++++++-------
 2 files changed, 74 insertions(+), 31 deletions(-)
 delete mode 100644 hpvm/projects/onnx/frontend/config.py

diff --git a/hpvm/projects/onnx/frontend/config.py b/hpvm/projects/onnx/frontend/config.py
deleted file mode 100644
index 8eac060c28..0000000000
--- a/hpvm/projects/onnx/frontend/config.py
+++ /dev/null
@@ -1,7 +0,0 @@
-model_name = "lenet"
-compile_type = 1 # 0 for HPVM Tensor Runtime, 1 for HPVM C Interface
-input_size = [1,2,3,4]
-onnx_file_dir = "../models/keras/lenet.onnx"
-opset_version_default = 10
-src_emit_dir = "./test_src"
-
diff --git a/hpvm/projects/onnx/frontend/main.py b/hpvm/projects/onnx/frontend/main.py
index 248c52fdd3..683ca3bdae 100644
--- a/hpvm/projects/onnx/frontend/main.py
+++ b/hpvm/projects/onnx/frontend/main.py
@@ -1,49 +1,99 @@
-import os
-import sys
-import numpy as np
+from pathlib import Path
+from typing import Iterable, Optional
 import onnx
-import glob
+
 
 def check_version(model, new_version):
     try:
         opset = model.opset_import[0].version if model.opset_import else 1
     except AttributeError:
         opset = 1  # default opset version set to 1 if not specified
-    print("opset version: ", opset)
     if opset != new_version:
-        #print('The model before conversion:\n{}'.format(model))
         from onnx import version_converter
+
         try:
             converted_model = version_converter.convert_version(model, new_version)
             return converted_model
         except RuntimeError as e:
-            print("Current version {} of ONNX model not supported!".format(opset))
-            print("Coversion failed with message below:")
-            raise e
-        #print('The model after conversion:\n{}'.format(converted_model))
+            raise RuntimeError(
+                f"Current version {opset} of ONNX model not supported!\n"
+                f"Conversion failed with message below: \n{e}"
+            )
     return model
 
-def compile(model):
-    from config import compile_type, input_size, opset_version_default, src_emit_dir
+
+def compile(
+    model,
+    input_size: Iterable[int],
+    output_dir: Path,
+    opset_version: Optional[int],
+    hpvmc: bool,
+):
     from graph_builder import GraphBuilder
     from graph_codegen import GraphCodeGen
     from hpvm_codegen import HpvmCodeGen
-    weights_dir = src_emit_dir
-    model = check_version(model, opset_version_default)
-    graphBuilder = GraphBuilder(model, None, "float32", weights_dir)
-    if compile_type == 0:
-        graphCodeGen = GraphCodeGen(graphBuilder.build_graph(), weights_dir, input_size)
-        graphCodeGen.compile()
-    elif compile_type == 1:
-        hpvmCodeGen = HpvmCodeGen(graphBuilder.build_graph(), weights_dir)
+
+    if opset_version is not None:
+        model = check_version(model, opset_version)
+    graphBuilder = GraphBuilder(model, None, "float32", output_dir)
+    if hpvmc:
+        hpvmCodeGen = HpvmCodeGen(graphBuilder.build_graph(), output_dir)
         hpvmCodeGen.compile()
     else:
-        raise ValueError("Wrong type of Compilation! Abort.")
+        graphCodeGen = GraphCodeGen(graphBuilder.build_graph(), output_dir, input_size)
+        graphCodeGen.compile()
+
+
+def parse_args():
+    import argparse
+
+    parser = argparse.ArgumentParser(description="ONNX to HPVM-C")
+    parser.add_argument("onnx_file", type=Path, help="Path to input ONNX file")
+    parser.add_argument(
+        "-s",
+        "--input-size",
+        type=int,
+        required=True,
+        nargs="+",
+        help="""Size of input tensor to the model.
+Usually 4 dim, including batch size.
+For example: -s 1 3 32 32""",
+    )
+    parser.add_argument(
+        "output_dir",
+        type=Path,
+        help="Output folder where source file and weight files are generated",
+    )
+    parser.add_argument("--opset", type=int, help="ONNX opset version (enforced)")
+    parser.add_argument(
+        "-c",
+        "--compile-mode",
+        type=str,
+        choices=["tensor", "hpvmc"],
+        default="hpvmc",
+        help="""Output mode.
+tensor: HPVM Tensor Runtime;
+hpvmc: HPVM C Interface. Default value is hpvmc.""",
+    )
+
+    args = parser.parse_args()
+    args.hpvmc = args.compile_mode == "hpvmc"
+    return args
+
 
 def main():
-    from config import onnx_file_dir
-    model = onnx.load(onnx_file_dir)
-    compile(model)
+    import os
+
+    args = parse_args()
+    os.makedirs(args.output_dir, exist_ok=True)
+    compile(
+        onnx.load(args.onnx_file),
+        args.input_size,
+        args.output_dir,
+        args.opset,
+        args.hpvmc,
+    )
+
 
 if __name__ == "__main__":
     main()
-- 
GitLab