From 40712a3cd23566fe5df7cd23a14909b21c645fc8 Mon Sep 17 00:00:00 2001 From: shingjan <yjshi03@gmail.com> Date: Sun, 31 May 2020 15:47:25 -0500 Subject: [PATCH] some minor changes on main --- hpvm/projects/onnx/frontend/graph_builder.py | 8 ++++---- hpvm/projects/onnx/frontend/graph_ir.py | 12 +++++++++++ hpvm/projects/onnx/frontend/main.py | 21 ++++++++------------ 3 files changed, 24 insertions(+), 17 deletions(-) diff --git a/hpvm/projects/onnx/frontend/graph_builder.py b/hpvm/projects/onnx/frontend/graph_builder.py index 7ffe32c538..0659fe3fcd 100644 --- a/hpvm/projects/onnx/frontend/graph_builder.py +++ b/hpvm/projects/onnx/frontend/graph_builder.py @@ -3,11 +3,10 @@ from onnx import numpy_helper from tensor import InputTensor, WeightTensor from graph_ir import * -support_onnx_ops = {"DepthwiseConv" : [2], +support_onnx_ops = {#"DepthwiseConv" : [2], "Conv" : [2], # only 2d supported here "MatMul" : None, "MaxPool": [2], # only 2d supported here - "Activation" : None, "BatchNormalization" : None, "Flatten" : None, "Add" : None, @@ -187,6 +186,8 @@ class DFG(object): return SoftMaxNode(onnx_node) elif onnx_node.op_type == "Relu": return ReluNode(onnx_node) + elif onnx_node.op_type == "Tanh": + return TanhNode(onnx_node) elif onnx_node.op_type == "BatchNormalization": return BatchNormalizationNode(onnx_node) elif onnx_node.op_type == "Pad": @@ -196,5 +197,4 @@ class DFG(object): elif onnx_node.op_type == "Flatten": return FlattenNode(onnx_node) else: - raise ValueError("Unsupported operator type!") - sys.exit("Unsupported operator type!") + raise ValueError("Unsupported operator type: {}!".format(onnx_node.op_type)) diff --git a/hpvm/projects/onnx/frontend/graph_ir.py b/hpvm/projects/onnx/frontend/graph_ir.py index 2cfed0c78f..64796514bc 100644 --- a/hpvm/projects/onnx/frontend/graph_ir.py +++ b/hpvm/projects/onnx/frontend/graph_ir.py @@ -193,6 +193,18 @@ class ReluNode(DFGNode): self.inst_str += "tensorRelu(" + mapped_input_name + "); \n" return self.inst_str +class TanhNode(DFGNode): + def __init__(self, layer): + DFGNode.__init__(self, layer) + + def codegen(self, tensors): + cur_node = self.onnx_node + mapped_input_name = tensors[cur_node.input[0]].get_mapped_name() + mapped_output_name = tensors[cur_node.output[0]].get_mapped_name() + self.inst_str += "void* " + mapped_output_name + " = " + self.inst_str += "tensorTanh(" + mapped_input_name + "); \n" + return self.inst_str + class BatchNormalizationNode(DFGNode): def __init__(self, layer): diff --git a/hpvm/projects/onnx/frontend/main.py b/hpvm/projects/onnx/frontend/main.py index 258e1d23a5..b9fefdecd7 100644 --- a/hpvm/projects/onnx/frontend/main.py +++ b/hpvm/projects/onnx/frontend/main.py @@ -5,9 +5,9 @@ import onnx import glob #from onnxruntime.backend.backend import OnnxRuntimeBackend as backend -onnx_file_dir = "../models/keras/lenet.onnx" +onnx_file_dir = "../models/keras/alexnet.onnx" src_emit_dir = "./test_src" -opset_version_default = 11 +opset_version_default = 10 def check_version(model, new_version): try: @@ -33,25 +33,20 @@ def check_version(model, new_version): return model def compile(model): - # TODO: make this in constant - # make a cmd option, default value -> constant weights_dir = src_emit_dir - # test_data_dir = '../models/mnist/test_data_set_0' - # converted_model = convert_version(model) - # model = check_version(model, 11) + model = check_version(model, opset_version_default) from graph_builder import GraphBuilder from graph_codegen import GraphCodeGen from hpvm_codegen import HpvmCodeGen - gBuilder = GraphBuilder(model, None, "float32", weights_dir) - #gCodegen = GraphCodeGen(gBuilder.build_graph(), weights_dir) - hCodegen = HpvmCodeGen(gBuilder.build_graph(), weights_dir) - #gCodegen.compile() - hCodegen.compile() + graphBuilder = GraphBuilder(model, None, "float32", weights_dir) + #graphCodeGen = GraphCodeGen(gBuilder.build_graph(), weights_dir) + #graphCodeGen.compile() + hpvmCodeGen = HpvmCodeGen(graphBuilder.build_graph(), weights_dir) + hpvmCodeGen.compile() def main(): # TODO: Put it in args model = onnx.load(onnx_file_dir) - # model = onnx.load('../models/keras/vgg16_cifar10.onnx') compile(model) if __name__ == "__main__": -- GitLab