diff --git a/hpvm/projects/onnx/frontend/graph_builder.py b/hpvm/projects/onnx/frontend/graph_builder.py
index f5cc3fa158bead127eebaac50afd6f9ea783ae4e..192e6afbb8fdb108ab51eace5dd4269fd1be1064 100644
--- a/hpvm/projects/onnx/frontend/graph_builder.py
+++ b/hpvm/projects/onnx/frontend/graph_builder.py
@@ -1,6 +1,7 @@
 import sys
 from onnx import numpy_helper
 from tensor import InputTensor, WeightTensor
+from graph_ir import *
 
 support_onnx_ops = {"DepthwiseConv" : [2],
                     "Conv" : [2], # only 2d supported here
@@ -162,98 +163,37 @@ class DFG(object):
     def __init__(self, graph, tensors):
         self.graph = graph
         self.tensors = tensors
-
-    def hasSingleInput(self, layer):
-        layer_name = layer.__class__.__name__
-        return layer_name in self.singleInputLayers
-
-    def hasMultipleInputs(self, layer):
-        layer_name = layer.__class__.__name__
-        return layer_name in self.multiInputLayers
-
-    def add_dfg_edge(self, inbound_node_name, dfg_node):
-        inbound_node_name = inbound_node_name.split(":")[0]
-        inbound_node_name = inbound_node_name.split("/")[0]
-        if inbound_node_name in self.node_map:
-            inbound_node = self.node_map[inbound_node_name]
-            print(inbound_node_name, " found!")
-            inbound_node.add_output(dfg_node)
-            dfg_node.add_input(inbound_node)
-        else:
-            print("--inbound node NOT FOUND!")
-
-    def add_to_graph(self, layer):
-        dfg_node = DFGNode(layer)
-        if not self.root_set:
-            self.root_node = dfg_node
-            self.root_set = True  # DFG root node is now set
-
-        if self.hasMultipleInputs(layer):
-            for j in range(len(layer.input)):
-                print(type(layer.input[j]))
-                print(layer.input[j].op.name)
-                self.add_dfg_edge(layer.input[j].op.name, dfg_node)
-        else:
-            print(layer.input.name)
-            self.add_dfg_edge(layer.input.name, dfg_node)
-        # Adding DFG node to name mapping
-        self.node_map[layer.name] = dfg_node
-
-    # Check if all predecessor nodes have been visited thus far - reverse
-    # postorder traversal
-    def predVisited(self, cur_node, visited_nodes):
-        for input_node in cur_node.inputs:
-            if input_node.layer_name not in visited_nodes:
-                return False
-        # All predecessors are visited
-        return True
-
-    def traverseNode(self, cur_node, visited_nodes):
-        # Skip visited nodes
-        if cur_node.layer_name in visited_nodes:
-            return
-
-        if self.predVisited(cur_node, visited_nodes):
-            print(cur_node.layer_type)
-            print(cur_node.layer_name)
-            visited_nodes[cur_node.layer_name] = True
-
-            # Invoking traversal on outbound nodes
-            for output_node in cur_node.outputs:
-                self.traverseNode(output_node, visited_nodes)
-
-            # NOTE: Assuming that no outbound edges implies the last node in
-            # the graph
-            if len(cur_node.outputs) == 0:
-                self.last_node = cur_node
-
-    # Build and Print the DFG in reverse postorder
+        self.nodes = list()
 
     def buildDFG(self):
         print("\n\n ****** Traversing and Printing DFG ******* \n\n")
-        visited_nodes = {}
-        # Starting traversal at the DFG root node
-        self.traverseNode(self.root_node, visited_nodes)
+        for node in self.graph.node:
+            self.nodes.append(self.emit_node(node))
 
     # This should be the place where partial evaluation happens
-    def emitNode(self, layer):
-        if layer.op_type == "Conv":
-            return Conv2DNode()
-        elif layer.op_type == "Tanh":
-            pass
-        elif layer.op_type == "MaxPool":
-            pass
-        elif layer.op_type == "Flatten":
-            pass
-        elif layer.op_type == "MatMul":
-            pass
-        elif layer.op_type == "Add":
-            pass
-        elif layer.op_type == "SoftMax":
-            pass
-        elif layer.op_type == "Identity":
-            pass
+    def emit_node(self, onnx_node):
+        if onnx_node.op_type == "Conv":
+            return Conv2DNode(onnx_node)
+        elif onnx_node.op_type == "MaxPool":
+            return MaxPool2DNode(onnx_node)
+        elif onnx_node.op_type == "AveragePool":
+            return AveragePool2DNode(onnx_node)
+        elif onnx_node.op_type == "MatMul":
+            return MatMulNode(onnx_node)
+        elif onnx_node.op_type == "Add":
+            return AddNode(onnx_node)
+        elif onnx_node.op_type == "Softmax":
+            return SoftMaxNode(onnx_node)
+        elif onnx_node.op_type == "Relu":
+            return ReluNode(onnx_node)
+        elif onnx_node.op_type == "BatchNormalization":
+            return BatchNormalizationNode(onnx_node)
+        elif onnx_node.op_type == "Pad":
+            return PadNode(onnx_node)
+        elif onnx_node.op_type == "Identity":
+            return IdentityNode(onnx_node)
+        elif onnx_node.op_type == "Flatten":
+            return FlattenNode(onnx_node)
         else:
             raise ValueError("Unsupported operator type!")
             sys.exit("Unsupported operator type!")
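
Note on emit_node above: the if/elif chain repeats the op names already tracked in support_onnx_ops. A table-driven variant would keep the mapping in one place; a minimal sketch, assuming a hypothetical module-level table (ONNX_OP_TO_NODE is not part of this patch; the node classes are the graph_ir ones imported above):

    # Hypothetical dispatch table; equivalent to the if/elif chain.
    ONNX_OP_TO_NODE = {
        "Conv": Conv2DNode,
        "MaxPool": MaxPool2DNode,
        "AveragePool": AveragePool2DNode,
        "MatMul": MatMulNode,
        "Add": AddNode,
        "Softmax": SoftMaxNode,
        "Relu": ReluNode,
        "BatchNormalization": BatchNormalizationNode,
        "Pad": PadNode,
        "Identity": IdentityNode,
        "Flatten": FlattenNode,
    }

    def emit_node(self, onnx_node):
        # Look up the IR class for this op type and wrap the ONNX node
        if onnx_node.op_type not in ONNX_OP_TO_NODE:
            raise ValueError("Unsupported operator type: " + onnx_node.op_type)
        return ONNX_OP_TO_NODE[onnx_node.op_type](onnx_node)
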
"MaxPool": + return MaxPool2DNode(onnx_node) + elif onnx_node.op_type == "AveragePool": + return AveragePool2DNode(onnx_node) + elif onnx_node.op_type == "MatMul": + return MatMulNode(onnx_node) + elif onnx_node.op_type == "Add": + return AddNode(onnx_node) + elif onnx_node.op_type == "Softmax": + return SoftMaxNode(onnx_node) + elif onnx_node.op_type == "Relu": + return ReluNode(onnx_node) + elif onnx_node.op_type == "BatchNormalization": + return BatchNormalizationNode(onnx_node) + elif onnx_node.op_type == "Pad": + return PadNode(onnx_node) + elif onnx_node.op_type == "Identity": + return IdentityNode(onnx_node) + elif onnx_node.op_type == "Flatten": + return FlattenNode(onnx_node) else: raise ValueError("Unsupported operator type!") sys.exit("Unsupported operator type!") diff --git a/hpvm/projects/onnx/frontend/graph_codegen.py b/hpvm/projects/onnx/frontend/graph_codegen.py index 5524c276a86f10a125a9f502ce3a66772c619bb2..9489fc19808a53ea269dfddeb17339e15a749acb 100644 --- a/hpvm/projects/onnx/frontend/graph_codegen.py +++ b/hpvm/projects/onnx/frontend/graph_codegen.py @@ -6,18 +6,18 @@ from graph_builder import * #from graph_ir import * from tensor import * +skip_layer = ["Identity", "Flatten", "Pad"] class GraphCodeGen(object): def __init__(self, DFG, weights_dir, test_data=None, test_labels=None): self.program_str = "" self.graph = DFG.graph self.tensors = DFG.tensors - self.nodes = self.graph.node + self.nodes = DFG.nodes self.var_cnt = 0 self.weights_dir = weights_dir self.test_data = test_data - self.test_labels =test_labels - self.skip_layer = ["Identity", "Flatten", "Pad"] + self.test_labels = test_labels ################################################ # Aux functions @@ -131,7 +131,7 @@ class GraphCodeGen(object): inst_str += self.tensors[cur_node.input[4]].get_mapped_name() + ", " inst_str += str(epsilon) inst_str += "); \n" - elif cur_node.op_type in self.skip_layer: + elif cur_node.op_type in skip_layer: pass else: raise ValueError("Not supported op type:" + cur_node.op_type + "! 
\n") @@ -168,6 +168,22 @@ class GraphCodeGen(object): def emit_graph(self): for node in self.nodes: + # check if all inputs of this node is mapped + for i in cur_node.input: + self.tensors[i].get_mapped_name() + # set var name for output node + if len(cur_node.output) > 1: + raise ValueError("Output number for a single layer larger than 1!") + if cur_node.op_type in self.skip_layer: + mapped_output_name = self.get_last_var() + else: + mapped_output_name = self.get_new_var() + output_name = cur_node.output[0] + self.tensors[output_name].set_mapped_name(mapped_output_name) + self.program_str += node.codegen() + + def emit_graph2(self): + for node in self.graph.nodes: #pass self.program_str += self.emit_node_call(node) @@ -201,10 +217,10 @@ class GraphCodeGen(object): def emit_batch_loop(self, x_test=None): # FIXME: Dimensions from test data - N = x_test.shape[0] - C = x_test.shape[1] - H = x_test.shape[2] - W = x_test.shape[3] + N = 1#x_test.shape[0] + C = 1#x_test.shape[1] + H = 1#x_test.shape[2] + W = 1#x_test.shape[3] loop_str = "" loop_str += "\nstartMemTracking(); \n\n" @@ -258,7 +274,7 @@ class GraphCodeGen(object): self.emit_header() self.emit_weights() self.emit_batch_loop() - self.emit_graph() + self.emit_graph2() self.emit_batch_loop_end() self.emit_footer() # Write the program to source/disk diff --git a/hpvm/projects/onnx/frontend/graph_ir.py b/hpvm/projects/onnx/frontend/graph_ir.py index 71b577ff7fccbcea71fd89b4aeb2d26d407ddb6f..52c7569a354a7b69389a958083ee5e30ec47596c 100644 --- a/hpvm/projects/onnx/frontend/graph_ir.py +++ b/hpvm/projects/onnx/frontend/graph_ir.py @@ -1,18 +1,16 @@ ################################################ # Top Level DFGNode interface ################################################ -class DFGNode(object): - def add_output(self, output_node): - self.outputs.append(output_node) - def add_input(self, input_node): - self.inputs.append(input_node) - def __init__(self, layer): - self.inputs = list() - self.outputs = list() - self.name = layer.name - self.op_type = layer.op_type +class DFGNode(object): + def __init__(self, onnx_node): + self.node = onnx_node + self.name = onnx_node.name + self.op_type = onnx_node.op_type + + def codegen(self, tensors): + pass ''' @@ -21,6 +19,8 @@ e.g. HardSigmoid, LeakyRelu, PRelu, Pow, Reciprocal, Relu, Selu, Sigmoid, Softplus, Sqrt, ThresholdedRelu, Abs, Ceil, Elu, Floor, Neg ''' + + class ActivationNode(DFGNode): pass @@ -31,6 +31,8 @@ In other words, they are logical comparison operators e.g. 
@@ -43,110 +45,200 @@ class AddNode(DFGNode):
     def __init__(self, layer):
         DFGNode.__init__(self, layer)
 
-    def codegen(self):
-        pass
+    def codegen(self, tensors):
+        cur_node = self.node
+        mapped_output_name = tensors[cur_node.output[0]].get_mapped_name()
+        left_input = tensors[cur_node.input[0]].get_mapped_name()
+        right_input = tensors[cur_node.input[1]].get_mapped_name()
+        inst_str = "void* " + mapped_output_name + " = "
+        inst_str += "tensorAdd(" + left_input + ", " + right_input + "); \n"
+        return inst_str
 
 
 class MatMulNode(DFGNode):
     def __init__(self, layer):
         DFGNode.__init__(self, layer)
 
-    def codegen(self):
-        pass
+    def codegen(self, tensors):
+        cur_node = self.node
+        mapped_output_name = tensors[cur_node.output[0]].get_mapped_name()
+        left_input = tensors[cur_node.input[0]].get_mapped_name()
+        right_input = tensors[cur_node.input[1]].get_mapped_name()
+        inst_str = "void* " + mapped_output_name + " = "
+        inst_str += "tensorGemmGPU(" + left_input + \
+            ", " + right_input + "); \n"
+        return inst_str
 
 
 class SoftMaxNode(DFGNode):
     def __init__(self, layer):
         DFGNode.__init__(self, layer)
 
-    def codegen(self):
-        pass
+    def codegen(self, tensors):
+        cur_node = self.node
+        mapped_output_name = tensors[cur_node.output[0]].get_mapped_name()
+        mapped_input_name = tensors[cur_node.input[0]].get_mapped_name()
+        inst_str = "void* " + mapped_output_name + " = "
+        inst_str += "tensorSoftmax(" + mapped_input_name + "); \n"
+        return inst_str
 
 
 class Conv2DNode(DFGNode):
     def __init__(self, layer):
         DFGNode.__init__(self, layer)
-        self.weights = layer.get_weights()[0]
-        print("\t", self.weights.shape)
-        self.use_bias = layer.use_bias
-        self.padding = layer.padding
-        self.strides = layer.strides
-        print("\t", self.strides)
-        print("\tPadding = ", self.padding)
 
-    def codegen(self):
-        pass
+    def codegen(self, tensors):
+        cur_node = self.node
+        mapped_output_name = tensors[cur_node.output[0]].get_mapped_name()
+        input_var_name = tensors[cur_node.input[0]].get_mapped_name()
+        weight = cur_node.input[1]
+        strides = list()
+        padding = 0
+        for attr in cur_node.attribute:
+            if attr.name == "pads":
+                padding = attr.ints[0]
+            elif attr.name == "strides":
+                for stride in attr.ints:
+                    strides.append(stride)
+
+        inst_str = "void* " + mapped_output_name + " = "
+        inst_str += "tensorConvolution(" + input_var_name + ", "
+        inst_str += tensors[weight].get_mapped_name() + ", "
+        inst_str += str(padding) + ", "
+        inst_str += str(padding) + ", "
+        inst_str += str(strides[0]) + ", "
+        inst_str += str(strides[1]) + ", "
+        inst_str += "1, 1); \n"
+
+        # check for bias add (optional): in ONNX it only appears in Conv,
+        # while in Keras a bias add could also exist in Dense
+        if len(cur_node.input) == 3:
+            # fresh variable names are allocated in GraphCodeGen, so derive
+            # a local name for the bias-added result from the output name
+            mapped_output_name2 = mapped_output_name + "_b"
+            inst_str += "void* " + mapped_output_name2 + " = "
+            inst_str += "tensorAdd(" + mapped_output_name + ", "
+            inst_str += tensors[cur_node.input[2]].get_mapped_name()
+            inst_str += "); \n"
+            tensors[cur_node.output[0]].set_mapped_name(mapped_output_name2)
+        return inst_str
 
 
-class DepthwiseConv2DNode(DFGNode):
+class MaxPool2DNode(DFGNode):
     def __init__(self, layer):
         DFGNode.__init__(self, layer)
-        self.weights = layer.get_weights()[0]
-        print("\t", self.weights.shape)
-        self.use_bias = layer.use_bias
-        self.padding = layer.padding
-        self.strides = layer.strides
-        print("\t", self.strides)
-        print("\tPadding = ", self.padding)
 
-    def codegen(self):
-        pass
+    def codegen(self, tensors):
+        cur_node = self.node
+        mapped_output_name = tensors[cur_node.output[0]].get_mapped_name()
+        input_var_name = tensors[cur_node.input[0]].get_mapped_name()
+        strides = list()
+        pool_size = list()
+        for attr in cur_node.attribute:
+            if attr.name == "kernel_shape":
+                for pool in attr.ints:
+                    pool_size.append(pool)
+            elif attr.name == "strides":
+                for stride in attr.ints:
+                    strides.append(stride)
+        # FIXME: Non-same padding is *NOT* currently supported
+        padding = 0
pool_type = "0" + # tensorPooling(input, pool_type, pool_h, pool_w, v_pad, h_pad, v_stride, h_stride) + inst_str += "void* " + mapped_output_name + " = " + inst_str += "tensorPooling(" + input_var_name + "," + \ + pool_type + "," + str(pool_size[0]) + "," + str(pool_size[1]) + inst_str += "," + str(padding) + "," + str(padding) + \ + "," + str(strides[0]) + "," + str(strides[1]) + inst_str += "); \n" + + +class AveragePool2DNode(DFGNode): + def __init__(self, layer): + DFGNode.__init__(self, layer) -class DenseNode(DFGNode): + def codegen(self, tensors): + input_var_name = self.tensors[cur_node.input[0]].get_mapped_name() + strides = list() + pool_size = list() + for attr in cur_node.attribute: + if attr.name == "kernel_shape": + for pool in attr.ints: + pool_size.append(pool) + elif attr.name == "strides": + for stride in attr.ints: + strides.append(stride) + # FIXME: Non-same padding is *NOT* currently supported + padding = 0 + pool_type = "1" + # tensorPooling(input, pool_type, pool_h, pool_w, v_pad, h_pad, v_stride, h_stride) + inst_str += "void* " + mapped_output_name + " = " + inst_str += "tensorPooling(" + input_var_name + "," + \ + pool_type + "," + str(pool_size[0]) + "," + str(pool_size[1]) + inst_str += "," + str(padding) + "," + str(padding) + \ + "," + str(strides[0]) + "," + str(strides[1]) + inst_str += "); \n" + + +class ReluNode(DFGNode): def __init__(self, layer): DFGNode.__init__(self, layer) - self.weights = layer.get_weights()[0] - print("\t", self.weights.shape) - self.use_bias = layer.use_bias - def codegen(self): + def codegen(self, tensors): + mapped_input_name = self.tensors[cur_node.input[0]].get_mapped_name() + inst_str += "void* " + mapped_output_name + " = " + inst_str += "tensorRelu(" + mapped_input_name + "); \n" + + +class BatchNormalizationNode(DFGNode): + def __init__(self, layer): + DFGNode.__init__(self, layer) + + def codegen(self, tensors): + mapped_input_name = self.tensors[cur_node.input[0]].get_mapped_name() + epsilon = "" + for attr in cur_node.attribute: + if attr.name == "epsilon": + epsilon = str(attr.f) + inst_str += "void* " + mapped_output_name + " = " + inst_str += "tensorBatchNorm(" + mapped_input_name + ", " + inst_str += self.tensors[cur_node.input[1]].get_mapped_name() + ", " + inst_str += self.tensors[cur_node.input[2]].get_mapped_name() + ", " + inst_str += self.tensors[cur_node.input[3]].get_mapped_name() + ", " + inst_str += self.tensors[cur_node.input[4]].get_mapped_name() + ", " + inst_str += str(epsilon) + inst_str += "); \n" + + +class PadNode(DFGNode): + def __init__(self, layer): + DFGNode.__init__(self, layer) + + def codegen(self, tensors): pass -class MaxPool2DNode(DFGNode): +class IdentityNode(DFGNode): def __init__(self, layer): DFGNode.__init__(self, layer) - self.pool_size = layer.pool_size - self.strides = layer.strides - print("\t pool_size = ", self.pool_size) - print("\t strides = ", self.strides) - def codegen(self): + def codegen(self, tensors): pass -class AveragePooling2DNode(DFGNode): +class FlattenNode(DFGNode): def __init__(self, layer): DFGNode.__init__(self, layer) - self.pool_size = layer.pool_size - self.strides = layer.strides - print("\t pool_size = ", self.pool_size) - print("\t strides = ", self.strides) - def codegen(self): + def codegen(self, tensors): pass class ZeroPadding2DNode(DFGNode): def __init__(self, layer): DFGNode.__init__(self, layer) - print("***ZeroPaddding \n") - self.padding = layer.padding - print("padding = ", self.padding) - def codegen(self): + def codegen(self, tensors): pass 
diff --git a/hpvm/projects/onnx/frontend/main.py b/hpvm/projects/onnx/frontend/main.py
index 5a76e80de1067f019747a705dbacdf96291c7fb0..3a9344015b7082fd8e59f92fc6e4faf0e80a0b04 100644
--- a/hpvm/projects/onnx/frontend/main.py
+++ b/hpvm/projects/onnx/frontend/main.py
@@ -5,6 +5,9 @@ import onnx
 import glob
 #from onnxruntime.backend.backend import OnnxRuntimeBackend as backend
 
+onnx_file_path = "../models/keras/lenet.onnx"
+src_emit_dir = "./test_src"
+
 def check_version(model, new_version):
     try:
         opset = model.opset_import[0].version if model.opset_import else 1
@@ -29,7 +32,9 @@ def check_version(model, new_version):
     return model
 
 def compile(model):
-    weights_dir = './test_src'
+    # TODO: make this a command-line option, with the
+    # constant above as the default value
+    weights_dir = src_emit_dir
     opset_version_default = 11
     # test_data_dir = '../models/mnist/test_data_set_0'
     # converted_model = convert_version(model)
@@ -41,7 +46,8 @@ def compile(model):
     gCodegen.compile()
 
 def main():
-    model = onnx.load('../models/keras/lenet.onnx')
+    # TODO: take the model path from command-line args
+    model = onnx.load(onnx_file_path)
     # model = onnx.load('../models/keras/vgg16_cifar10.onnx')
     compile(model)
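
End to end, the frontend is driven as in main.py; a minimal sketch of the flow after this patch, assuming the DFG and GraphCodeGen constructors shown above (how the tensors map is built from the model's inputs and initializers is outside this patch):

    import onnx
    from graph_builder import DFG
    from graph_codegen import GraphCodeGen

    model = onnx.load("../models/keras/lenet.onnx")
    tensors = {}  # placeholder: name -> InputTensor/WeightTensor,
                  # built by the frontend's weight-loading code (not shown here)
    dfg = DFG(model.graph, tensors)  # wraps each ONNX node in a graph_ir DFGNode
    dfg.buildDFG()
    GraphCodeGen(dfg, "./test_src").compile()  # emits the generated source
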