Skip to content
Snippets Groups Projects
main.py 1.93 KiB
import os
import sys
import numpy as np
import onnx
import glob
#from onnxruntime.backend.backend import OnnxRuntimeBackend as backend

# Default locations used by main()/compile(): the ONNX model that main() loads,
# and the directory compile() hands to the graph builder / code generator.
# TODO: expose both as command-line options.
onnx_file_dir, src_emit_dir = "../models/keras/lenet.onnx", "./test_src"

def _default_opset_version(model):
    """Return the default-domain opset version declared by *model* (1 if none)."""
    try:
        imports = list(model.opset_import)
    except AttributeError:
        return 1  # default opset version set to 1 if not specified
    for entry in imports:
        # opset_import may also list extension domains (e.g. "ai.onnx.ml");
        # only the default ONNX domain ("" or its alias "ai.onnx") describes
        # the standard operator set that the version converter targets.
        if getattr(entry, "domain", "") in ("", "ai.onnx"):
            return entry.version
    # No default-domain entry: fall back to the first entry, as before.
    return imports[0].version if imports else 1


def check_version(model, new_version):
    """Return *model* converted to opset *new_version* if it targets another opset.

    The current opset is read from the model's ``opset_import`` entry for the
    default ONNX domain (``""`` or ``"ai.onnx"``). If no such entry exists, the
    first ``opset_import`` entry is used. If the model has no ``opset_import``
    at all (or it is empty), opset 1 is assumed. The detected version is
    printed.

    Args:
        model: an ONNX ``ModelProto`` (any object exposing ``opset_import``).
        new_version: target default-domain opset version.

    Returns:
        *model* itself when it already targets *new_version*. Otherwise, the
        model returned by ``onnx.version_converter.convert_version``.

    Raises:
        RuntimeError: re-raised unchanged from the converter when conversion
            fails, after printing a short diagnostic.
    """
    opset = _default_opset_version(model)
    print("opset version: ", opset)
    if opset == new_version:
        return model

    # A full list of supported adapters can be found here:
    # https://github.com/onnx/onnx/blob/master/onnx/version_converter.py#L21
    # The converter is imported only on this conversion path.
    from onnx import version_converter
    try:
        # Apply the version conversion on the original model.
        return version_converter.convert_version(model, new_version)
    except RuntimeError:
        print("Current version {} of ONNX model not supported!".format(opset))
        print("Conversion failed with message below:")
        raise

def compile(model, weights_dir=None):
    """Build a graph from *model* and run the code generator on it.

    No opset normalization is done here. To pin the opset first, call
    ``check_version(model, <version>)`` on the model before passing it in.

    Args:
        model: the loaded ONNX model (``main`` passes the result of
            ``onnx.load``). It is handed to ``GraphBuilder`` unchanged.
        weights_dir: directory passed to both ``GraphBuilder`` and
            ``GraphCodeGen``. Defaults to the module-level ``src_emit_dir``,
            which was previously the only (hard-coded) choice.
    """
    # NOTE: this function shadows the builtin compile(); the name is kept so
    # existing callers keep working.
    if weights_dir is None:
        weights_dir = src_emit_dir
    # Project-local modules, imported at call time.
    from graph_builder import GraphBuilder
    from graph_codegen import GraphCodeGen
    # NOTE(review): the meaning of the None argument and of "float32" is not
    # visible here. Presumably they are an optional input/test-data slot and
    # the working dtype. Confirm against GraphBuilder's signature.
    g_builder = GraphBuilder(model, None, "float32", weights_dir)
    g_codegen = GraphCodeGen(g_builder.build_graph(), weights_dir)
    g_codegen.compile()

def main():
    """Script entry point: load the ONNX model at ``onnx_file_dir`` and compile it."""
    # TODO: take the model path from command-line arguments instead of the
    # module-level constant.
    compile(onnx.load(onnx_file_dir))


if __name__ == "__main__":
    main()