Commit 40712a3c authored by shingjan's avatar shingjan

some minor changes on main

parent f3ce696b
@@ -3,11 +3,10 @@ from onnx import numpy_helper
 from tensor import InputTensor, WeightTensor
 from graph_ir import *
 
-support_onnx_ops = {"DepthwiseConv" : [2],
+support_onnx_ops = {#"DepthwiseConv" : [2],
                     "Conv" : [2], # only 2d supported here
                     "MatMul" : None,
                     "MaxPool": [2], # only 2d supported here
                     "Activation" : None,
                     "BatchNormalization" : None,
                     "Flatten" : None,
                     "Add" : None,
@@ -187,6 +186,8 @@ class DFG(object):
             return SoftMaxNode(onnx_node)
         elif onnx_node.op_type == "Relu":
             return ReluNode(onnx_node)
+        elif onnx_node.op_type == "Tanh":
+            return TanhNode(onnx_node)
         elif onnx_node.op_type == "BatchNormalization":
             return BatchNormalizationNode(onnx_node)
         elif onnx_node.op_type == "Pad":
@@ -196,5 +197,4 @@ class DFG(object):
         elif onnx_node.op_type == "Flatten":
             return FlattenNode(onnx_node)
         else:
-            raise ValueError("Unsupported operator type!")
-            sys.exit("Unsupported operator type!")
+            raise ValueError("Unsupported operator type: {}!".format(onnx_node.op_type))
@@ -193,6 +193,18 @@ class ReluNode(DFGNode):
         self.inst_str += "tensorRelu(" + mapped_input_name + "); \n"
         return self.inst_str
 
+
+class TanhNode(DFGNode):
+    def __init__(self, layer):
+        DFGNode.__init__(self, layer)
+
+    def codegen(self, tensors):
+        cur_node = self.onnx_node
+        mapped_input_name = tensors[cur_node.input[0]].get_mapped_name()
+        mapped_output_name = tensors[cur_node.output[0]].get_mapped_name()
+        self.inst_str += "void* " + mapped_output_name + " = "
+        self.inst_str += "tensorTanh(" + mapped_input_name + "); \n"
+        return self.inst_str
 
 class BatchNormalizationNode(DFGNode):
     def __init__(self, layer):
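To illustrate the new node's output: assuming the tensor map assigns names like t3 and t4 to the node's input and output (hypothetical names for illustration), TanhNode.codegen appends an instruction of the form

void* t4 = tensorTanh(t3);

mirroring ReluNode, with only the emitted runtime call changed from tensorRelu to tensorTanh.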
@@ -5,9 +5,9 @@ import onnx
 import glob
 
 #from onnxruntime.backend.backend import OnnxRuntimeBackend as backend
-onnx_file_dir = "../models/keras/lenet.onnx"
+onnx_file_dir = "../models/keras/alexnet.onnx"
 src_emit_dir = "./test_src"
-opset_version_default = 11
+opset_version_default = 10
 
 def check_version(model, new_version):
     try:
@@ -33,25 +33,20 @@ def check_version(model, new_version):
     return model
 
 def compile(model):
     # TODO: make this in constant
     # make a cmd option, default value -> constant
     weights_dir = src_emit_dir
     # test_data_dir = '../models/mnist/test_data_set_0'
     # converted_model = convert_version(model)
-    # model = check_version(model, 11)
+    model = check_version(model, opset_version_default)
     from graph_builder import GraphBuilder
     from graph_codegen import GraphCodeGen
     from hpvm_codegen import HpvmCodeGen
-    gBuilder = GraphBuilder(model, None, "float32", weights_dir)
-    #gCodegen = GraphCodeGen(gBuilder.build_graph(), weights_dir)
-    hCodegen = HpvmCodeGen(gBuilder.build_graph(), weights_dir)
-    #gCodegen.compile()
-    hCodegen.compile()
+    graphBuilder = GraphBuilder(model, None, "float32", weights_dir)
+    #graphCodeGen = GraphCodeGen(gBuilder.build_graph(), weights_dir)
+    #graphCodeGen.compile()
+    hpvmCodeGen = HpvmCodeGen(graphBuilder.build_graph(), weights_dir)
+    hpvmCodeGen.compile()
 
 def main():
     # TODO: Put it in args
     model = onnx.load(onnx_file_dir)
     # model = onnx.load('../models/keras/vgg16_cifar10.onnx')
     compile(model)
 
 if __name__ == "__main__":
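As a usage sketch mirroring compile() above (the model path is the one left commented in main(); check_version, opset_version_default, and src_emit_dir are assumed to be in scope as defined in this file, and the local names builder/codegen are illustrative):

import onnx
from graph_builder import GraphBuilder
from hpvm_codegen import HpvmCodeGen

model = onnx.load("../models/keras/vgg16_cifar10.onnx")  # any exported ONNX model
model = check_version(model, opset_version_default)      # normalize the opset first
builder = GraphBuilder(model, None, "float32", src_emit_dir)
codegen = HpvmCodeGen(builder.build_graph(), src_emit_dir)
codegen.compile()                                         # emits the generated source into src_emit_dir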