diff --git a/hpvm/projects/onnx/frontend/codegen_tensor.py b/hpvm/projects/onnx/frontend/codegen_tensor.py index 132debc269e8be8259d47d74f4f4e906c600e81a..d382db94d3b7eef9ad2bc5bfb6aa77fea5f989c6 100644 --- a/hpvm/projects/onnx/frontend/codegen_tensor.py +++ b/hpvm/projects/onnx/frontend/codegen_tensor.py @@ -1,141 +1,63 @@ -import sys -import numpy as np -import os +from codegen_hpvm import emit_weights, get_input_args, make_c_identifier +from os import PathLike +from pathlib import Path +from typing import Dict, List, Tuple, Union +import jinja2 + +from graph_builder import DFG from tensor import WeightTensor -from utils import skip_layer + +TEMPLATE_FILE = "template_tensor.cpp" +loader = jinja2.FileSystemLoader(searchpath="./") +template_env = jinja2.Environment(loader=loader, trim_blocks=True) +template = template_env.get_template(TEMPLATE_FILE) + class TensorCodeGen(object): - def __init__(self, dfg, weights_dir, test_data_shape=None): - self.program_str = "" - self.graph = dfg.graph + def __init__(self, dfg: DFG, output_dir: PathLike, test_data_shape=None): self.tensors = dfg.tensors - self.nodes = dfg.nodes - self.var_cnt = -1 - self.weights_dir = weights_dir + self.dfg = dfg + self.var_count = 0 + self.output_dir = Path(output_dir) self.test_data_shape = test_data_shape + # self.variables is a "onnx name to our name" map + # Each value is (varname, bool) and the bool indicates + # "is root node input" or not. 
+ IdenT = Union[str, int] + input_args = get_input_args(dfg.inputs, dfg.tensors) + self.variables: Dict[str, IdenT] = {k: make_c_identifier(k) for k in input_args} + print(self.variables) ################################################ # Aux functions ################################################ - def get_last_var(self): - return "var_" + str(self.var_cnt) - - def get_new_var(self): - self.var_cnt = self.var_cnt + 1 - return "var_" + str(self.var_cnt) + def _allocate_varname(self) -> str: + varname = f"var_{self.var_count}" + self.var_count += 1 + return varname ################################################ # CodeGen functions ################################################ - def emit_weights(self): - weights_str = "" - weights_str += "std::string dir_prefix = std::string(\"" + str(self.weights_dir) + "\");\n" - weights_str += "std::string input_path = dir_prefix + std::string(\"input.bin\");\n" - weights_str += "std::string labels_path = dir_prefix + std::string(\"labels.bin\");\n" - for tensor in self.tensors.values(): - if isinstance(tensor, WeightTensor): - weights_str += self.emit_single_weight(tensor) - self.program_str += weights_str - - def emit_single_weight(self, tensor): - N = tensor.shape[0] - C = tensor.shape[1] - H = tensor.shape[2] - W = tensor.shape[3] - mapped_name = tensor.get_mapped_name() - file_path = mapped_name + "_path" - unique_file_name = file_path + ".bin" - weight_str = "std::string " + file_path + " = " + " dir_prefix + std::string(\"" - weight_str += unique_file_name + "\"); \n" - weight_str += "void* " + mapped_name + " = " + " readTrainedWeights(" - weight_str += file_path + ".c_str(), 0," + str(N) + "," + str(C) + "," + str(H) + "," + str(W) - weight_str += "); \n" - return weight_str - - def emit_graph(self): - for node in self.nodes: - # check if all inputs of this node is mapped - for i in node.input: - self.tensors[i].get_mapped_name() - # set var name for output node - if len(node.output) > 1: - raise 
ValueError("Output number for a single layer larger than 1!") - if node.op_type in skip_layer or node.op_type == "BiasAdd": - mapped_output_name = self.get_last_var() - else: - mapped_output_name = self.get_new_var() - self.tensors[node.output[0]].set_mapped_name(mapped_output_name) - self.program_str += node.codegen(self.tensors) - - def emit_header(self): - headers = "\n#include <stdio.h> \n" - headers += "#include <stdlib.h> \n" - headers += "#include <unistd.h> \n" - headers += "#include <fcntl.h> \n" - headers += "#include <sys/types.h> \n" - headers += "#include <sys/stat.h> \n" - headers += "#include <string.h> \n" - headers += "#include <chrono> \n" - headers += "#include \"../../tensor_runtime/include/tensor_runtime.h\" \n" - headers += "#include \"../include/utils.h\" \n\n" - - main_func = "int main(){ \n\n" - initialization = "llvm_hpvm_initTensorRt(0); \n\n" - self.program_str += headers - self.program_str += main_func - self.program_str += initialization - - def emit_footer(self, test_data=None): - destructors = "\nllvm_hpvm_cleanupTensorRt(); \n" - end_main = "\nreturn 0; \n\n}\n" - self.program_str += destructors - self.program_str += end_main - - def emit_batch_loop(self, test_data_shape): - N = test_data_shape[0] - C = test_data_shape[1] - H = test_data_shape[2] - W = test_data_shape[3] - - loop_str = "" - loop_str += "\nstartMemTracking(); \n\n" - - loop_str += "int test_input_size = " + str(N) + "; \n" - loop_str += "int batch_size = " + str(N) + "; \n" - # FIXME: Ceiling for batch_count - loop_str += "int batch_count = test_input_size / batch_size; \n" - loop_str += "float final_accuracy = 0.0; \n\n" - - loop_str += "for(int i = 0; i < batch_count; i++){ \n\n" - loop_str += "int start = i * batch_size; \n" - loop_str += "int end = (i + 1) * batch_size; \n" - - loop_str += "\nvoid* input = readInputBatch(input_path.c_str(),0,start,end," - loop_str += str(C) + "," + str(H) + "," + str(W) + "); \n\n" - - self.program_str += loop_str - - def 
emit_batch_loop_end(self): - end_loop_str = "" - end_loop_str += "\nuint32_t* labels = readLabelsBatch3(labels_path.c_str(),start,end); \n" - mapped_output_var = self.tensors[self.graph.output[0].name].get_mapped_name() - accuracy_call = "\nfloat accuracy = computeAccuracy3(labels, " + \ - mapped_output_var + "); \n" - end_loop_str += accuracy_call - end_loop_str += "final_accuracy += accuracy; \n" - end_loop_str += "freeBatchMemory(); \n " - end_loop_str += "\n}\n\n" - - end_loop_str += "final_accuracy = final_accuracy / batch_count; \n" - end_loop_str += "dumpFinalAccuracy(final_accuracy); \n\n" - - self.program_str += end_loop_str - def emit_source_to_file(self, src_dir): - f = open(src_dir + "/src.cc", "w+") - f.write(self.program_str) - f.close() + def emit_graph(self) -> List[dict]: + graph_code = [] + for node in self.dfg.traverse_order: + func_name, extra_args = node.codegen() + if func_name == "": # No code generation + # Node must have single input, we equate the output to + # the input and skip code generation. + assert len(node.input) == 1 and len(node.output) == 1 + self.variables[node.output[0]] = self.variables[node.input[0]] + continue + input_args = [self.variables[arg] for arg in node.input] + extra_args + varname = self._allocate_varname() + self.variables[node.output[0]] = varname + graph_code.append( + {"output": varname, "inputs": input_args, "function": func_name} + ) + return graph_code ################################################ # Compile is a top level function to compile an onnx model into C/C++ @@ -143,13 +65,19 @@ class TensorCodeGen(object): ################################################ def compile(self): - #if os.path.exists(self.weights_dir): - # raise ValueError("Weight dir existed. 
Compilation interrupted!") - #os.mkdir(self.weights_dir) - self.emit_header() - self.emit_weights() - self.emit_batch_loop(self.test_data_shape) - self.emit_graph() - self.emit_batch_loop_end() - self.emit_footer() - self.emit_source_to_file(self.weights_dir) + # nodes = self.emit_hpvm_node_structures() + # inputs, output = self.emit_root_io() + graph_code = self.emit_graph() + input_arg = self.dfg.discover_input_var() + output_arg = self.variables[self.dfg.output.name] + with open(self.output_dir / "src.cc", "w") as f: + f.write( + template.render( + input=input_arg, + input_shape=self.test_data_shape, + output=output_arg, + graph_code=graph_code, + weights=emit_weights(self.tensors), + output_dir=self.output_dir, + ) + ) diff --git a/hpvm/projects/onnx/frontend/graph_builder.py b/hpvm/projects/onnx/frontend/graph_builder.py index f7260f2b2c1edfc1f88271e99386a1db6d20319d..542cfb535809f0334ce0a5c953023b933e4f1ecb 100644 --- a/hpvm/projects/onnx/frontend/graph_builder.py +++ b/hpvm/projects/onnx/frontend/graph_builder.py @@ -90,8 +90,8 @@ class DFG(object): def __init__(self, graph: GraphT, tensors: Dict[str, Tensor]): if len(graph.output) > 1: raise ValueError("Only single-output graph is supported") - self.inputs: List[str] = graph.input - self.output: str = graph.output[0] + self.inputs: List[NodeT] = graph.input + self.output: NodeT = graph.output[0] self._onnx_defs, self._onnx_uses = self.def_use(graph.node) self._var_count = 0 self.tensors = tensors @@ -101,6 +101,19 @@ class DFG(object): def traverse_order(self) -> List[g.DFGNode]: return list(nx.topological_sort(self.graph)) + def discover_input_var(self) -> Optional[str]: + """Guess which input tensor is the "input" to the ONNX model. + + This is useful when we batch through the input tensor. 
+ It's a guess because sometimes ONNX model put everything in the inputs + (weights, etcs.), and there's no apparent way to tell which one is + the actual input (that we can batch over).""" + + if len(self.inputs) == 1: + return self.inputs[0].name + first_arg_first_node = self.traverse_order[0].input[0] + return first_arg_first_node + @staticmethod def def_use(nodes: list) -> Tuple[dict, dict]: """Computes def/use relation from a list of node. diff --git a/hpvm/projects/onnx/frontend/graph_ir.py b/hpvm/projects/onnx/frontend/graph_ir.py index 5c0a2eef3cfd70e78efd95da5a36dc8cde65d68e..467ed63ccb6f292c7c7f7bc9f8a1f1c9e3b8d167 100644 --- a/hpvm/projects/onnx/frontend/graph_ir.py +++ b/hpvm/projects/onnx/frontend/graph_ir.py @@ -10,8 +10,8 @@ class DFGNode(object): self.input = onnx_node.input self.output = onnx_node.output - def codegen(self, tensors): - return "\n***Not Implemented***\n" + def codegen(self): + return "", [] def hpvm_codegen(self): return "", [] @@ -45,14 +45,8 @@ class LogicalOpNode(DFGNode): class AddNode(DFGNode): - def codegen(self, tensors): - inst_str = "" - left_input = tensors[self.input[0]].get_mapped_name() - right_input = tensors[self.input[1]].get_mapped_name() - mapped_output_name = tensors[self.output[0]].get_mapped_name() - inst_str += "void* " + mapped_output_name + " = " - inst_str += "tensorAdd(" + left_input + ", " + right_input + "); \n" - return inst_str + def codegen(self): + return "tensorAdd", [] def hpvm_codegen(self): return "__hpvm__tensor_add", [] @@ -66,43 +60,23 @@ class BiasAddNode(DFGNode): self.input.append(self.output[0]) self.input.append(onnx_conv_node.input[2]) - def codegen(self, tensors): - inst_str = "" - left_input = tensors[self.input[0]].get_mapped_name() - right_input = tensors[self.input[1]].get_mapped_name() - mapped_output_name = tensors[self.output[0]].get_mapped_name() - inst_str += mapped_output_name + " = " - inst_str += "tensorAdd(" + left_input + ", " + right_input + "); \n" - return inst_str + def 
codegen(self): + return "tensorAdd", [] def hpvm_codegen(self): return "__hpvm__tensor_add", [] class MatMulNode(DFGNode): - - def codegen(self, tensors): - inst_str = "" - left_input = tensors[self.input[0]].get_mapped_name() - right_input = tensors[self.input[1]].get_mapped_name() - mapped_output_name = tensors[self.output[0]].get_mapped_name() - inst_str += "void* " + mapped_output_name + " = " - inst_str += "tensorGemmGPU(" + left_input + \ - ", " + right_input + "); \n" - return inst_str + def codegen(self): + return "tensorGemmGPU", [] def hpvm_codegen(self): return "__hpvm__tensor_mul", [] class SoftMaxNode(DFGNode): - - def codegen(self, tensors): - inst_str = "" - mapped_input_name = tensors[self.input[0]].get_mapped_name() - mapped_output_name = tensors[self.output[0]].get_mapped_name() - inst_str += "void* " + mapped_output_name + " = " - inst_str += "tensorSoftmax(" + mapped_input_name + "); \n" - return inst_str + def codegen(self): + return "tensorSoftmax", [] def hpvm_codegen(self): return "__hpvm__tensor_softmax", [] @@ -127,20 +101,8 @@ class Conv2DNode(DFGNode): for stride in attr.ints: self.strides.append(stride) - def codegen(self, tensors): - inst_str = "" - mapped_input_name = tensors[self.input[0]].get_mapped_name() - mapped_output_name = tensors[self.output[0]].get_mapped_name() - - inst_str += "void* " + mapped_output_name + " = " - inst_str += "tensorConvolution(" + mapped_input_name + ", " - inst_str += tensors[self.input[1]].get_mapped_name() + ", " - inst_str += str(self.padding) + ", " - inst_str += str(self.padding) + ", " - inst_str += str(self.strides[0]) + ", " - inst_str += str(self.strides[1]) + ", " - inst_str += "1, 1); \n" - return inst_str + def codegen(self): + return "tensorConvolution", [self.padding, self.padding, self.strides[0], self.strides[1]] def hpvm_codegen(self): return "__hpvm__tensor_convolution", [self.padding, self.padding, self.strides[0], self.strides[1]] @@ -162,17 +124,8 @@ class MaxPool2DNode(DFGNode): 
self.strides.append(stride) - def codegen(self, tensors): - mapped_input_name = tensors[self.input[0]].get_mapped_name() - mapped_output_name = tensors[self.output[0]].get_mapped_name() - # tensorPooling(input, pool_type, pool_h, pool_w, v_pad, h_pad, v_stride, h_stride) - inst_str = "void* " + mapped_output_name + " = " - inst_str += "tensorPooling(" + mapped_input_name + "," + \ - self.pool_type + "," + str(self.pool_size[0]) + "," + str(self.pool_size[1]) - inst_str += "," + str(self.padding) + "," + str(self.padding) + \ - "," + str(self.strides[0]) + "," + str(self.strides[1]) - inst_str += "); \n" - return inst_str + def codegen(self): + return "tensorPooling", [self.pool_type, *self.pool_size, self.padding, self.padding, *self.strides] def hpvm_codegen(self): return "__hpvm__tensor_pool_max", [*self.pool_size, self.padding, self.padding, *self.strides] @@ -194,16 +147,8 @@ class AveragePool2DNode(DFGNode): for stride in attr.ints: self.strides.append(stride) - def codegen(self, tensors): - mapped_input_name = tensors[self.input[0]].get_mapped_name() - mapped_output_name = tensors[self.output[0]].get_mapped_name() - inst_str = "void* " + mapped_output_name + " = " - inst_str += "tensorPooling(" + mapped_input_name + "," + \ - self.pool_type + "," + str(self.pool_size[0]) + "," + str(self.pool_size[1]) - inst_str += "," + str(self.padding) + "," + str(self.padding) + \ - "," + str(self.strides[0]) + "," + str(self.strides[1]) - inst_str += "); \n" - return inst_str + def codegen(self): + return "tensorPooling", [self.pool_type, *self.pool_size, self.padding, self.padding, *self.strides] def hpvm_codegen(self): return "__hpvm__tensor_pool_avg", [*self.pool_size, self.padding, self.padding, *self.strides] @@ -211,24 +156,16 @@ class AveragePool2DNode(DFGNode): class ReluNode(DFGNode): - def codegen(self, tensors): - mapped_input_name = tensors[self.input[0]].get_mapped_name() - mapped_output_name = tensors[self.output[0]].get_mapped_name() - inst_str = "void* " 
+ mapped_output_name + " = " - inst_str += "tensorRelu(" + mapped_input_name + "); \n" - return inst_str + def codegen(self): + return "tensorRelu", [] def hpvm_codegen(self): return "__hpvm__tensor_relu", [] class TanhNode(DFGNode): - def codegen(self, tensors): - mapped_input_name = tensors[self.input[0]].get_mapped_name() - mapped_output_name = tensors[self.output[0]].get_mapped_name() - inst_str = "void* " + mapped_output_name + " = " - inst_str += "tensorTanh(" + mapped_input_name + "); \n" - return inst_str + def codegen(self): + return "tensorTanh", [] def hpvm_codegen(self): return "__hpvm__tensor_tanh", [] @@ -243,31 +180,13 @@ class BatchNormalizationNode(DFGNode): if attr.name == "epsilon": self.epsilon = str(attr.f) - def codegen(self, tensors): - mapped_input_name = tensors[self.input[0]].get_mapped_name() - mapped_output_name = tensors[self.output[0]].get_mapped_name() - inst_str = "void* " + mapped_output_name + " = " - inst_str += "tensorBatchNorm(" + mapped_input_name + ", " - inst_str += tensors[self.input[1]].get_mapped_name() + ", " - inst_str += tensors[self.input[2]].get_mapped_name() + ", " - inst_str += tensors[self.input[3]].get_mapped_name() + ", " - inst_str += tensors[self.input[4]].get_mapped_name() + ", " - inst_str += str(self.epsilon) - inst_str += "); \n" - return inst_str + def codegen(self): + return "tensorBatchNorm", [self.epsilon] def hpvm_codegen(self): return "__hpvm__tensor_batchnorm", [self.epsilon] -class PadNode(DFGNode): - def codegen(self, tensors): - return "" - -class IdentityNode(DFGNode): - def codegen(self, tensors): - return "" - class FlattenNode(DFGNode): def __init__(self, name: str, op_type: str, input, output): self.name = name @@ -284,9 +203,6 @@ class FlattenNode(DFGNode): _, suffix = nodes[0].name.split('_') return cls(f'Flatten_{suffix}', 'Flatten', nodes[0].input, nodes[-1].output) - def codegen(self, tensors): - return "" - class ZeroPadding2DNode(DFGNode): pass @@ -295,3 +211,10 @@ class 
DepthwiseConv2DNode(DFGNode): class DenseNode(DFGNode): pass + + +class PadNode(DFGNode): + pass + +class IdentityNode(DFGNode): + pass \ No newline at end of file diff --git a/hpvm/projects/onnx/frontend/template_tensor.cpp b/hpvm/projects/onnx/frontend/template_tensor.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ae3060836451069ed6aee9321d714db1b66c1723 --- /dev/null +++ b/hpvm/projects/onnx/frontend/template_tensor.cpp @@ -0,0 +1,50 @@ +#include <chrono> +#include <fcntl.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +#include <sys/stat.h> +#include <sys/types.h> +#include <unistd.h> + +#include "../../tensor_runtime/include/tensor_runtime.h" +#include "../include/utils.h" + +int main() { /* generated program entry point; this template is rendered into src.cc */ + llvm_hpvm_initTensorRt(0); /* initialise the runtime before any readTrainedWeights call, matching the order the previous Python emitter produced */ + std::string dir_prefix = "{{output_dir}}/"; /* trailing '/' is required: the file names below are appended directly */ + std::string input_path = dir_prefix + "input.bin"; + std::string labels_path = dir_prefix + "labels.bin"; +{% for w in weights %} + std::string {{w.name}}_path = dir_prefix + std::string("{{w.filename}}"); + void* {{w.name}} = readTrainedWeights({{w.name}}_path.c_str(), 0, {{w.shape|join(', ')}}); +{% endfor %} + + startMemTracking(); + + int test_input_size = {{input_shape[0]}}; + int batch_size = {{input_shape[0]}}; + // FIXME: Ceiling for batch_count + int batch_count = test_input_size / batch_size; + float final_accuracy = 0.0; + + for (int i = 0; i < batch_count; i++) { + int start = i * batch_size; + int end = (i + 1) * batch_size; + void *{{input}} = readInputBatch(input_path.c_str(), 0, start, end, {{input_shape[1:]|join(', ')}}); /* input_shape is (N, C, H, W); the batch extent is given by start/end, so only C, H, W are passed */ + +{% for code in graph_code %} + auto {{code.output}} = {{code.function}}({{code.inputs|join(', ')}}); +{% endfor %} + + uint32_t* labels = readLabelsBatch3(labels_path.c_str(), start, end); + float accuracy = computeAccuracy3(labels, {{output}}); + final_accuracy += accuracy; + freeBatchMemory(); + } + + final_accuracy = final_accuracy / batch_count; + dumpFinalAccuracy(final_accuracy); + llvm_hpvm_cleanupTensorRt(); + 
return 0; +} \ No newline at end of file