Skip to content
Snippets Groups Projects
Commit 0a93985e authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Parse argument from cli instead of config.py

parent 0664ba1e
No related branches found
No related tags found
No related merge requests found
model_name = "lenet"
compile_type = 1 # 0 for HPVM Tensor Runtime, 1 for HPVM C Interface
input_size = [1,2,3,4]
onnx_file_dir = "../models/keras/lenet.onnx"
opset_version_default = 10
src_emit_dir = "./test_src"
import os
import sys
import numpy as np
from pathlib import Path
from typing import Iterable, Optional
import onnx
import glob
def check_version(model, new_version):
try:
opset = model.opset_import[0].version if model.opset_import else 1
except AttributeError:
opset = 1 # default opset version set to 1 if not specified
print("opset version: ", opset)
if opset != new_version:
#print('The model before conversion:\n{}'.format(model))
from onnx import version_converter
try:
converted_model = version_converter.convert_version(model, new_version)
return converted_model
except RuntimeError as e:
print("Current version {} of ONNX model not supported!".format(opset))
print("Coversion failed with message below:")
raise e
#print('The model after conversion:\n{}'.format(converted_model))
raise RuntimeError(
f"Current version {opset} of ONNX model not supported!\n"
f"Conversion failed with message below: \n{e}"
)
return model
def compile(model):
from config import compile_type, input_size, opset_version_default, src_emit_dir
def compile(
model,
input_size: Iterable[int],
output_dir: Path,
opset_version: Optional[int],
hpvmc: bool,
):
from graph_builder import GraphBuilder
from graph_codegen import GraphCodeGen
from hpvm_codegen import HpvmCodeGen
weights_dir = src_emit_dir
model = check_version(model, opset_version_default)
graphBuilder = GraphBuilder(model, None, "float32", weights_dir)
if compile_type == 0:
graphCodeGen = GraphCodeGen(graphBuilder.build_graph(), weights_dir, input_size)
graphCodeGen.compile()
elif compile_type == 1:
hpvmCodeGen = HpvmCodeGen(graphBuilder.build_graph(), weights_dir)
if opset_version is not None:
model = check_version(model, opset_version)
graphBuilder = GraphBuilder(model, None, "float32", output_dir)
if hpvmc:
hpvmCodeGen = HpvmCodeGen(graphBuilder.build_graph(), output_dir)
hpvmCodeGen.compile()
else:
raise ValueError("Wrong type of Compilation! Abort.")
graphCodeGen = GraphCodeGen(graphBuilder.build_graph(), output_dir, input_size)
graphCodeGen.compile()
def parse_args():
import argparse
parser = argparse.ArgumentParser(description="ONNX to HPVM-C")
parser.add_argument("onnx_file", type=Path, help="Path to input ONNX file")
parser.add_argument(
"-s",
"--input-size",
type=int,
required=True,
nargs="+",
help="""Size of input tensor to the model.
Usually 4 dim, including batch size.
For example: -s 1 3 32 32""",
)
parser.add_argument(
"output_dir",
type=Path,
help="Output folder where source file and weight files are generated",
)
parser.add_argument("--opset", type=int, help="ONNX opset version (enforced)")
parser.add_argument(
"-c",
"--compile-mode",
type=str,
choices=["tensor", "hpvmc"],
default="hpvmc",
help="""Output mode.
tensor: HPVM Tensor Runtime;
hpvmc: HPVM C Interface. Default value is hpvmc.""",
)
args = parser.parse_args()
args.hpvmc = args.compile_mode == "hpvmc"
return args
def main():
from config import onnx_file_dir
model = onnx.load(onnx_file_dir)
compile(model)
import os
args = parse_args()
os.makedirs(args.output_dir, exist_ok=True)
compile(
onnx.load(args.onnx_file),
args.input_size,
args.output_dir,
args.opset,
args.hpvmc,
)
if __name__ == "__main__":
main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment