# main.py 1.93 KiB
import os
import sys
import numpy as np
import onnx
import glob
#from onnxruntime.backend.backend import OnnxRuntimeBackend as backend
# Default input model and output directory for the code generator.
onnx_file_dir, src_emit_dir = (
    "../models/keras/lenet.onnx",  # ONNX model loaded by main()
    "./test_src",  # directory compile() passes on as weights_dir
)
def check_version(model, new_version):
    """Return *model* converted to ONNX opset *new_version* if needed.

    The current opset is read from the ``opset_import`` entry for the
    default ONNX operator domain (``""`` or its alias ``"ai.onnx"``).
    Entries for other domains (e.g. ``"ai.onnx.ml"``) are ignored: they may
    appear before the default-domain entry, so the first element of
    ``opset_import`` is not necessarily the one that matters. If the model
    has no ``opset_import`` or no default-domain entry, opset 1 is assumed.

    Args:
        model: ONNX ``ModelProto`` (or any object exposing ``opset_import``).
        new_version: Target opset version for the default domain.

    Returns:
        *model* itself when its opset already equals *new_version*,
        otherwise the result of ``onnx.version_converter.convert_version``.

    Raises:
        RuntimeError: Re-raised unchanged from the ONNX version converter
            when the conversion is not supported.
    """
    try:
        opset = next(
            (
                entry.version
                for entry in model.opset_import
                if getattr(entry, "domain", "") in ("", "ai.onnx")
            ),
            1,  # default opset version set to 1 if not specified
        )
    except AttributeError:
        opset = 1  # model exposes no opset_import at all
    print("opset version: ", opset)

    if opset == new_version:
        return model

    # Imported only when a conversion is actually required.
    # A full list of supported adapters can be found here:
    # https://github.com/onnx/onnx/blob/master/onnx/version_converter.py#L21
    from onnx import version_converter

    try:
        return version_converter.convert_version(model, new_version)
    except RuntimeError:
        print("Current version {} of ONNX model not supported!".format(opset))
        print("Conversion failed with message below:")
        raise  # bare raise keeps the original traceback
def compile(model, weights_dir=None):
    """Build the graph for *model* and run code generation on it.

    Note: this module-level name shadows the builtin ``compile`` inside this
    module; it is kept for compatibility with existing callers.

    Args:
        model: Loaded ONNX model (as returned by ``onnx.load`` in ``main``).
        weights_dir: Directory handed to both ``GraphBuilder`` and
            ``GraphCodeGen``. Defaults to the module-level ``src_emit_dir``.
    """
    # TODO: expose weights_dir as a command-line option.
    if weights_dir is None:
        weights_dir = src_emit_dir

    # Project-local modules, imported when compilation is requested.
    from graph_builder import GraphBuilder
    from graph_codegen import GraphCodeGen

    # NOTE(review): the model is not passed through check_version() here, so
    # it is compiled at whatever opset it was saved with — confirm intended.
    builder = GraphBuilder(model, None, "float32", weights_dir)
    codegen = GraphCodeGen(builder.build_graph(), weights_dir)
    codegen.compile()
def main(model_path=None):
    """Load an ONNX model from disk and compile it.

    Args:
        model_path: Path to the ``.onnx`` file to compile. Defaults to the
            module-level ``onnx_file_dir``.
    """
    # TODO: take model_path from the command line.
    if model_path is None:
        model_path = onnx_file_dir
    model = onnx.load(model_path)
    compile(model)
if __name__ == "__main__":
    # Run the compiler only when executed as a script, not on import.
    main()