Commit a6309ad6 authored by Yifan Zhao

Improved torch2hpvm setup script

parent 0d92db2e
 build/
 dist/
-onnx2hpvm.egg-info/
+*.egg-info/
+.ipynb_checkpoints/
name: onnxfront
channels:
  - defaults
dependencies:
  - pip
  - python=3.8.5=h7579374_1
  - jinja2
  - networkx
  - pip:
      - onnx==1.8.0
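The conda environment above pins the frontend's Python version and packages. As a quick, hypothetical sanity check that the same dependencies are importable from Python (only the package names are taken from the file above, everything else is an assumption):

# Hedged sketch: print the installed versions of the frontend's dependencies.
# Requires Python >= 3.8 for importlib.metadata; raises PackageNotFoundError
# for anything that is not installed.
import importlib.metadata as metadata

for package in ("jinja2", "networkx", "onnx"):
    print(package, metadata.version(package))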
 from setuptools import setup
 setup(
-    name='torch2hpvm',
-    version='1.0',
-    description='PyTorch frontend for HPVM',
-    author='Yuanjing Shi, Yifan Zhao',
-    author_email='ys26@illinois.edu, yifanz16@illinois.edu',
-    packages=['torch2hpvm'],
-    install_requires=[],
+    name="torch2hpvm",
+    version="1.0",
+    description="PyTorch frontend for HPVM",
+    author="Yuanjing Shi, Yifan Zhao",
+    author_email="ys26@illinois.edu, yifanz16@illinois.edu",
+    packages=["torch2hpvm"],
+    install_requires=["jinja2>=2.11", "networkx>=2.5", "onnx>=1.8.0"],
+    entry_points={"console_scripts": ["torch2hpvm=torch2hpvm:main"]},
 )
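With the new metadata, installing the package pulls in jinja2, networkx, and onnx, and the console_scripts entry installs a `torch2hpvm` command that dispatches to torch2hpvm.main. The package-level imports below also re-export compile(), so the frontend can be driven from Python directly. A minimal usage sketch, where the file names and values are hypothetical and the parameters mirror compile()'s signature further down:

from pathlib import Path
from torch2hpvm import compile

compile(
    onnx_file=Path("model.onnx"),   # hypothetical input ONNX model
    output_dir=Path("model_out"),   # hypothetical output directory
    input_size=500,                 # hypothetical size value
    prefix=None,
    batch_size=None,
    opset=11,                       # convert the model to opset 11 before codegen
    hpvmc=True,                     # emit code for the HPVM C interface
)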
from .compile import compile
from .__main__ import main
+import os
 from pathlib import Path
-from typing import Optional
-import onnx
+from .compile import compile
-def check_version(model, new_version):
-    try:
-        opset = model.opset_import[0].version if model.opset_import else 1
-    except AttributeError:
-        opset = 1  # default opset version set to 1 if not specified
-    if opset != new_version:
-        from onnx import version_converter
-        try:
-            converted_model = version_converter.convert_version(model, new_version)
-            return converted_model
-        except RuntimeError as e:
-            raise RuntimeError(
-                f"Current version {opset} of ONNX model not supported!\n"
-                f"Conversion failed with message below: \n{e}"
-            )
-    return model
-def compile(
-    onnx_file: Path,
-    output_dir: Path,
-    input_size: int,
-    prefix: Optional[str],
-    batch_size: Optional[int],
-    opset: Optional[int],
-    hpvmc: bool,
-):
-    from frontend.graph_builder import DFG
-    from frontend.codegen_tensor import TensorCodeGen
-    from frontend.codegen_hpvm import HpvmCodeGen
-    model = onnx.load(onnx_file)
-    if opset is not None:
-        model = check_version(model, opset)
-    model = onnx.shape_inference.infer_shapes(model)
-    dfg = DFG(model.graph)
-    if hpvmc:
-        hpvm_code_gen = HpvmCodeGen(dfg, output_dir, input_size, batch_size, prefix)
-        hpvm_code_gen.compile()
-    else:
-        tensor_code_gen = TensorCodeGen(dfg, output_dir, input_size, batch_size, prefix)
-        tensor_code_gen.compile()
-    dfg.dump_weights(output_dir)
 def parse_args():
@@ -96,12 +50,6 @@ hpvmc: HPVM C Interface. Default value is hpvmc.""",
 def main():
-    import os
     args = parse_args()
     os.makedirs(args.output_dir, exist_ok=True)
     compile(**vars(args))
-if __name__ == "__main__":
-    main()
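main() hands the parsed argparse namespace straight to the frontend with compile(**vars(args)), so every destination name defined in parse_args() (whose body is elided in the hunk above) has to match a keyword parameter of compile(). A toy, hypothetical illustration of that dispatch pattern, unrelated to the real flags:

import argparse

def run(onnx_file, output_dir, input_size):
    # Stand-in for compile(); just echoes what it received.
    print(onnx_file, output_dir, input_size)

parser = argparse.ArgumentParser()
parser.add_argument("onnx_file")
parser.add_argument("output_dir")
parser.add_argument("input_size", type=int)
args = parser.parse_args(["model.onnx", "out", "500"])  # hypothetical values
run(**vars(args))  # each dest name lines up with a parameter of run()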
from pathlib import Path
from typing import Optional, Union
import onnx
from onnx import version_converter
from .codegen_hpvm import HpvmCodeGen
from .codegen_tensor import TensorCodeGen
from .graph_builder import DFG
PathLike = Union[Path, str]
def check_version(model, new_version):
    try:
        opset = model.opset_import[0].version if model.opset_import else 1
    except AttributeError:
        opset = 1  # default opset version set to 1 if not specified
    if opset != new_version:
        try:
            converted_model = version_converter.convert_version(model, new_version)
            return converted_model
        except RuntimeError as e:
            raise RuntimeError(
                f"Current version {opset} of ONNX model not supported!\n"
                f"Conversion failed with message below: \n{e}"
            )
    return model

def compile(
    onnx_file: Path,
    output_dir: Path,
    input_size: int,
    prefix: Optional[str],
    batch_size: Optional[int],
    opset: Optional[int],
    hpvmc: bool,
):
    model = onnx.load(onnx_file)
    if opset is not None:
        model = check_version(model, opset)
    model = onnx.shape_inference.infer_shapes(model)
    dfg = DFG(model.graph)
    if hpvmc:
        hpvm_code_gen = HpvmCodeGen(dfg, output_dir, input_size, batch_size, prefix)
        hpvm_code_gen.compile()
    else:
        tensor_code_gen = TensorCodeGen(dfg, output_dir, input_size, batch_size, prefix)
        tensor_code_gen.compile()
    dfg.dump_weights(output_dir)
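check_version() above is a thin wrapper around onnx's version_converter: it reads the opset the model declares and converts only when it differs from the requested one. A minimal sketch of exercising it on its own, where the model file name is hypothetical:

import onnx
from torch2hpvm.compile import check_version

model = onnx.load("lenet_mnist.onnx")      # hypothetical ONNX model file
print(model.opset_import[0].version)       # opset the model currently declares
model = check_version(model, 11)           # converts via version_converter if needed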