diff --git a/hpvm/projects/onnx/frontend/config.py b/hpvm/projects/onnx/frontend/config.py deleted file mode 100644 index 8eac060c2838aca3820a83c05bf7633d6ecb9146..0000000000000000000000000000000000000000 --- a/hpvm/projects/onnx/frontend/config.py +++ /dev/null @@ -1,7 +0,0 @@ -model_name = "lenet" -compile_type = 1 # 0 for HPVM Tensor Runtime, 1 for HPVM C Interface -input_size = [1,2,3,4] -onnx_file_dir = "../models/keras/lenet.onnx" -opset_version_default = 10 -src_emit_dir = "./test_src" - diff --git a/hpvm/projects/onnx/frontend/main.py b/hpvm/projects/onnx/frontend/main.py index 248c52fdd3f93f20a08bec1e0ed8939dc838828a..683ca3bdae6fc7648ddb9e26307f50fd9a21a272 100644 --- a/hpvm/projects/onnx/frontend/main.py +++ b/hpvm/projects/onnx/frontend/main.py @@ -1,49 +1,99 @@ -import os -import sys -import numpy as np +from pathlib import Path +from typing import Iterable, Optional import onnx -import glob + def check_version(model, new_version): try: opset = model.opset_import[0].version if model.opset_import else 1 except AttributeError: opset = 1 # default opset version set to 1 if not specified - print("opset version: ", opset) if opset != new_version: - #print('The model before conversion:\n{}'.format(model)) from onnx import version_converter + try: converted_model = version_converter.convert_version(model, new_version) return converted_model except RuntimeError as e: - print("Current version {} of ONNX model not supported!".format(opset)) - print("Coversion failed with message below:") - raise e - #print('The model after conversion:\n{}'.format(converted_model)) + raise RuntimeError( + f"Current version {opset} of ONNX model not supported!\n" + f"Conversion failed with message below: \n{e}" + ) return model -def compile(model): - from config import compile_type, input_size, opset_version_default, src_emit_dir + +def compile( + model, + input_size: Iterable[int], + output_dir: Path, + opset_version: Optional[int], + hpvmc: bool, +): from graph_builder import GraphBuilder from graph_codegen import GraphCodeGen from hpvm_codegen import HpvmCodeGen - weights_dir = src_emit_dir - model = check_version(model, opset_version_default) - graphBuilder = GraphBuilder(model, None, "float32", weights_dir) - if compile_type == 0: - graphCodeGen = GraphCodeGen(graphBuilder.build_graph(), weights_dir, input_size) - graphCodeGen.compile() - elif compile_type == 1: - hpvmCodeGen = HpvmCodeGen(graphBuilder.build_graph(), weights_dir) + + if opset_version is not None: + model = check_version(model, opset_version) + graphBuilder = GraphBuilder(model, None, "float32", output_dir) + if hpvmc: + hpvmCodeGen = HpvmCodeGen(graphBuilder.build_graph(), output_dir) hpvmCodeGen.compile() else: - raise ValueError("Wrong type of Compilation! Abort.") + graphCodeGen = GraphCodeGen(graphBuilder.build_graph(), output_dir, input_size) + graphCodeGen.compile() + + +def parse_args(): + import argparse + + parser = argparse.ArgumentParser(description="ONNX to HPVM-C") + parser.add_argument("onnx_file", type=Path, help="Path to input ONNX file") + parser.add_argument( + "-s", + "--input-size", + type=int, + required=True, + nargs="+", + help="""Size of input tensor to the model. +Usually 4 dim, including batch size. +For example: -s 1 3 32 32""", + ) + parser.add_argument( + "output_dir", + type=Path, + help="Output folder where source file and weight files are generated", + ) + parser.add_argument("--opset", type=int, help="ONNX opset version (enforced)") + parser.add_argument( + "-c", + "--compile-mode", + type=str, + choices=["tensor", "hpvmc"], + default="hpvmc", + help="""Output mode. +tensor: HPVM Tensor Runtime; +hpvmc: HPVM C Interface. Default value is hpvmc.""", + ) + + args = parser.parse_args() + args.hpvmc = args.compile_mode == "hpvmc" + return args + def main(): - from config import onnx_file_dir - model = onnx.load(onnx_file_dir) - compile(model) + import os + + args = parse_args() + os.makedirs(args.output_dir, exist_ok=True) + compile( + onnx.load(args.onnx_file), + args.input_size, + args.output_dir, + args.opset, + args.hpvmc, + ) + if __name__ == "__main__": main()