Skip to content
Snippets Groups Projects
Commit 40712a3c authored by shingjan's avatar shingjan
Browse files

some minor changes on main

parent f3ce696b
No related branches found
No related tags found
No related merge requests found
...@@ -3,11 +3,10 @@ from onnx import numpy_helper ...@@ -3,11 +3,10 @@ from onnx import numpy_helper
from tensor import InputTensor, WeightTensor from tensor import InputTensor, WeightTensor
from graph_ir import * from graph_ir import *
support_onnx_ops = {"DepthwiseConv" : [2], support_onnx_ops = {#"DepthwiseConv" : [2],
"Conv" : [2], # only 2d supported here "Conv" : [2], # only 2d supported here
"MatMul" : None, "MatMul" : None,
"MaxPool": [2], # only 2d supported here "MaxPool": [2], # only 2d supported here
"Activation" : None,
"BatchNormalization" : None, "BatchNormalization" : None,
"Flatten" : None, "Flatten" : None,
"Add" : None, "Add" : None,
...@@ -187,6 +186,8 @@ class DFG(object): ...@@ -187,6 +186,8 @@ class DFG(object):
return SoftMaxNode(onnx_node) return SoftMaxNode(onnx_node)
elif onnx_node.op_type == "Relu": elif onnx_node.op_type == "Relu":
return ReluNode(onnx_node) return ReluNode(onnx_node)
elif onnx_node.op_type == "Tanh":
return TanhNode(onnx_node)
elif onnx_node.op_type == "BatchNormalization": elif onnx_node.op_type == "BatchNormalization":
return BatchNormalizationNode(onnx_node) return BatchNormalizationNode(onnx_node)
elif onnx_node.op_type == "Pad": elif onnx_node.op_type == "Pad":
...@@ -196,5 +197,4 @@ class DFG(object): ...@@ -196,5 +197,4 @@ class DFG(object):
elif onnx_node.op_type == "Flatten": elif onnx_node.op_type == "Flatten":
return FlattenNode(onnx_node) return FlattenNode(onnx_node)
else: else:
raise ValueError("Unsupported operator type!") raise ValueError("Unsupported operator type: {}!".format(onnx_node.op_type))
sys.exit("Unsupported operator type!")
...@@ -193,6 +193,18 @@ class ReluNode(DFGNode): ...@@ -193,6 +193,18 @@ class ReluNode(DFGNode):
self.inst_str += "tensorRelu(" + mapped_input_name + "); \n" self.inst_str += "tensorRelu(" + mapped_input_name + "); \n"
return self.inst_str return self.inst_str
class TanhNode(DFGNode):
def __init__(self, layer):
DFGNode.__init__(self, layer)
def codegen(self, tensors):
cur_node = self.onnx_node
mapped_input_name = tensors[cur_node.input[0]].get_mapped_name()
mapped_output_name = tensors[cur_node.output[0]].get_mapped_name()
self.inst_str += "void* " + mapped_output_name + " = "
self.inst_str += "tensorTanh(" + mapped_input_name + "); \n"
return self.inst_str
class BatchNormalizationNode(DFGNode): class BatchNormalizationNode(DFGNode):
def __init__(self, layer): def __init__(self, layer):
......
...@@ -5,9 +5,9 @@ import onnx ...@@ -5,9 +5,9 @@ import onnx
import glob import glob
#from onnxruntime.backend.backend import OnnxRuntimeBackend as backend #from onnxruntime.backend.backend import OnnxRuntimeBackend as backend
onnx_file_dir = "../models/keras/lenet.onnx" onnx_file_dir = "../models/keras/alexnet.onnx"
src_emit_dir = "./test_src" src_emit_dir = "./test_src"
opset_version_default = 11 opset_version_default = 10
def check_version(model, new_version): def check_version(model, new_version):
try: try:
...@@ -33,25 +33,20 @@ def check_version(model, new_version): ...@@ -33,25 +33,20 @@ def check_version(model, new_version):
return model return model
def compile(model): def compile(model):
# TODO: make this in constant
# make a cmd option, default value -> constant
weights_dir = src_emit_dir weights_dir = src_emit_dir
# test_data_dir = '../models/mnist/test_data_set_0' model = check_version(model, opset_version_default)
# converted_model = convert_version(model)
# model = check_version(model, 11)
from graph_builder import GraphBuilder from graph_builder import GraphBuilder
from graph_codegen import GraphCodeGen from graph_codegen import GraphCodeGen
from hpvm_codegen import HpvmCodeGen from hpvm_codegen import HpvmCodeGen
gBuilder = GraphBuilder(model, None, "float32", weights_dir) graphBuilder = GraphBuilder(model, None, "float32", weights_dir)
#gCodegen = GraphCodeGen(gBuilder.build_graph(), weights_dir) #graphCodeGen = GraphCodeGen(gBuilder.build_graph(), weights_dir)
hCodegen = HpvmCodeGen(gBuilder.build_graph(), weights_dir) #graphCodeGen.compile()
#gCodegen.compile() hpvmCodeGen = HpvmCodeGen(graphBuilder.build_graph(), weights_dir)
hCodegen.compile() hpvmCodeGen.compile()
def main(): def main():
# TODO: Put it in args # TODO: Put it in args
model = onnx.load(onnx_file_dir) model = onnx.load(onnx_file_dir)
# model = onnx.load('../models/keras/vgg16_cifar10.onnx')
compile(model) compile(model)
if __name__ == "__main__": if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment