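"""Build a compilation-ready dataflow graph (DFG) from an ONNX model.

GraphBuilder validates the model, checks operator support, collects weight,
input, and intermediate tensors, and hands the result to the DFG class for
traversal and node emission.
"""
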
# DFGNode and Conv2DNode are referenced below; they are assumed to live in
# graph_ir alongside Node.
from graph_ir import Node, DFGNode, Conv2DNode
from common import InputTensor, WeightTensor

# Supported ONNX operators. A value of None means the operator is accepted
# unconditionally; a list restricts the allowed kernel_shape rank
# (e.g. [2] means only 2D kernels are supported).
support_onnx_ops = {"DepthwiseConv": [2],
                    "Conv": [2],      # only 2d supported here
                    "MatMul": None,
                    "MaxPool": [2],   # only 2d supported here
                    "Activation": None,
                    "BatchNormalization": None,
                    "Flatten": None,
                    "Add": None,
                    "Relu": None,
                    "Softmax": None,
                    "Identity": None,
                    "Pad": None,
                    "AveragePool": None,
                    "Tanh": None}

class GraphBuilder(object):
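    """Builds the tensor table for a checked ONNX model and produces a DFG."""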
    def __init__(self, model, shape, dtype, weight_dir):
        self._check_model(model)
        self._check_ops(model)
        self.model = model
        self.dtype = dtype
        self.graph = model.graph
        self.weight_dir = weight_dir
        self.shape = shape if shape else self._build_shape()
        self.tensors = dict()

    ################################################
    # Aux functions for graph building
    ################################################

    def _check_model(self, onnx_model):
        try:
            from onnx import checker, onnx_cpp2py_export
            if hasattr(checker, 'check_model'):
                # try to use ONNX's own model checker before converting the model
                try:
                    checker.check_model(onnx_model)
                    print("ONNX model checked: valid!")
                except onnx_cpp2py_export.checker.ValidationError as e:
                    import warnings
                    warnings.warn(str(e))
        except ImportError as e:
            raise ImportError(
                "Unable to import onnx.checker, which is required: {}".format(e))

    def _check_ops(self, model):
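        """Raise ValueError if the model contains operators outside support_onnx_ops."""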
        unsupport = dict()
        for node in model.graph.node:
            if node.op_type not in support_onnx_ops:
                if node.op_type not in unsupport:
                    unsupport[node.op_type] = 1
                else:
                    unsupport[node.op_type] += 1
        if len(unsupport) != 0:
            print(sorted(unsupport.items(), key=lambda x: x[1], reverse=True))
            raise ValueError(
                "Above operator(s) not currently supported! Compilation Aborted.")

    def _build_shape(self):
        shape = {}
        for graph_input in self.graph.input:
            # get the type of the input tensor
            tensor_type = graph_input.type.tensor_type
            # record the shape if one is present
            if tensor_type.HasField("shape"):
                shape[graph_input.name] = tensor_type.shape
        return shape

    def _parse_array(self, tensor_proto):
        try:
            from onnx.numpy_helper import to_array
        except ImportError as e:
            raise ImportError(
                "Unable to import onnx, which is required: {}".format(e))
        np_array = to_array(tensor_proto).reshape(tuple(tensor_proto.dims))
        return np_array

    def _parse_value_proto(self, value_proto):
        """Parse ValueProto or raw str."""
        try:
            name = value_proto.name
        except AttributeError:
            name = value_proto
        return name

    def _parse_dtype(self, value_proto, dtype):
        """Parse dtype."""
        try:
            from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
            return TENSOR_TYPE_TO_NP_TYPE[value_proto.type.tensor_type.elem_type].name
        except AttributeError:
            return dtype

    def _support_check(self, node):
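        """Return True if the node's op_type (and kernel rank, where restricted) is supported."""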
        op_name = node.op_type
        #print(op_name)
        if op_name not in support_onnx_ops:
            return False
        else:
            if support_onnx_ops[op_name] is None:
                return True
            else:
                #print(type(node.attribute))
                for attr in node.attribute:
                    # partially evaluate the kernel shape
                    if attr.name == 'kernel_shape':
                        # TODO: do not assume every kernel_shape is stored as INTS
                        return len(attr.ints) in support_onnx_ops[op_name]
                return False
               
    def _dump_weight(self, weight_tensor):
        # NOTE: only reports the weight for now; actual serialization into
        # self.weight_dir is not implemented here.
        print("Dump weight: {0}".format(weight_tensor.name))


    ################################################
    # Top level Graph Building functions
    # return the compilation-ready graph
    ################################################

    def build_graph(self):
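        """Collect weights, inputs, and intermediate tensors, then return a DFG.

        Raises ValueError if a node uses an unsupported operator or refers to
        an input tensor that has not been seen yet.
        """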
        # parse weight
        weight_cnt = 0
        for weight_tensor in self.graph.initializer:
            self.tensors[weight_tensor.name] = WeightTensor(weight_tensor)
            self.tensors[weight_tensor.name].set_mapped_name("weight_" + str(weight_cnt))
            weight_cnt += 1
        # parse input
        for i in self.graph.input:
            if i.name not in self.tensors:
                self.tensors[i.name] = InputTensor(i.name)
                # FIXME: This input name is hardcoded
                self.tensors[i.name].set_mapped_name("input")
        # parse intermediate tensor
        for node in self.graph.node:
            op_name = node.op_type
            #print("###############################")
            if not self._support_check(node):
                raise ValueError(
                        "Operator not currently supported: `{0}`!".format(op_name))
            #print("attribute: " + str(node.attribute))
            #print("input: " + str(node.input))
            #print("output: " + str(node.output))
            #print("###############################")
            for i in node.input:
                if i not in self.tensors:
                    raise ValueError(
                        "Compilation interrupted: missing input `{0}`!".format(i))
            for i in node.output:
                if i not in self.tensors:
                    self.tensors[i] = InputTensor(i)
        # Dump weights
        for tensor in self.tensors.values():
            if isinstance(tensor, WeightTensor):
                self._dump_weight(tensor)
        return DFG(self.graph, self.tensors)

class DFG(object):
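    """Dataflow graph of DFG nodes with reverse-postorder traversal support."""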

    root_set = False

    def __init__(self, graph, tensors):
        self.graph = graph
        self.tensors = tensors
        self.node_map = dict()  # layer name -> DFG node, filled by add_to_graph

    def hasSingleInput(self, layer):
        # NOTE: singleInputLayers/multiInputLayers are not defined in this file;
        # they are expected to be provided elsewhere.
        layer_name = layer.__class__.__name__
        return layer_name in self.singleInputLayers

    def hasMultipleInputs(self, layer):
        layer_name = layer.__class__.__name__
        return layer_name in self.multiInputLayers

    def add_dfg_edge(self, inbound_node_name, dfg_node):
        inbound_node_name = inbound_node_name.split(":")[0]
        inbound_node_name = inbound_node_name.split("/")[0]
        if inbound_node_name in self.node_map:
            inbound_node = self.node_map[inbound_node_name]
            print(inbound_node_name, " found!")
            inbound_node.add_output(dfg_node)
            dfg_node.add_input(inbound_node)
        else:
            print("--inbound node NOT FOUND!")

    def add_to_graph(self, layer):
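        """Wrap layer in a DFG node, wire its inbound edges, and register it in node_map."""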
        dfg_node = DFGNode(layer)
        if not self.root_set:
            self.root_node = dfg_node
            self.root_set = True  # DFG root node is now set

        if self.hasMultipleInputs(layer):
            for j in range(len(layer.input)):
                print(type(layer.input[j]))
                print(layer.input[j].op.name)
                self.add_dfg_edge(layer.input[j].op.name, dfg_node)
        else:
            print(layer.input.name)
            self.add_dfg_edge(layer.input.name, dfg_node)
        # Adding DFG node to name mapping
        self.node_map[layer.name] = dfg_node

    # Check if all predecessor nodes have been visited thus far - reverse
    # postorder traversal

    def predVisited(self, cur_node, visited_nodes):
        for input_node in cur_node.inputs:
            if input_node.layer_name not in visited_nodes:
                return False
        # All predecessors are visited
        return True

    def traverseNode(self, cur_node, visited_nodes):
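        """Visit cur_node once all predecessors are visited, then recurse on its outputs."""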
        # Skip visited nodes
        if cur_node.layer_name in visited_nodes:
            return

        if self.predVisited(cur_node, visited_nodes):
            print(cur_node.layer_type)
            print(cur_node.layer_name)
            visited_nodes[cur_node.layer_name] = True

            # Invoking traversal on outbound nodes
            for output_node in cur_node.outputs:
                self.traverseNode(output_node, visited_nodes)

            # NOTE: Assuming that no outbound edges implies the last node in
            # the graph
            if len(cur_node.outputs) == 0:
                self.last_node = cur_node

    # Build and print the DFG in reverse postorder

    def buildDFG(self):
        print("\n\n ****** Traversing and Printing DFG ******* \n\n")
        visited_nodes = {}
        # Starting traversal at the DFG root node
        self.traverseNode(self.root_node, visited_nodes)

    # This should be the place where partial evaluation happens
    def emitNode(self, layer):
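        """Translate an ONNX node into a DFG node; only Conv is handled so far, the rest are stubs."""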
        if layer.op_type == "Conv":
            return Conv2DNode()
        elif layer.op_type == "Tanh":
            pass
        elif layer.op_type == "MaxPool":
            pass
        elif layer.op_type == "Flatten":
            pass
        elif layer.op_type == "MatMul":
            pass
        elif layer.op_type == "Add":
            pass
        elif layer.op_type == "SoftMax":
            pass
        elif layer.op_type == "Identity":
            pass
        else:
            raise ValueError("Unsupported operator type!")
            sys.exit("Unsupported operator type!")
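

if __name__ == "__main__":
    # Minimal usage sketch (illustrative assumptions, not part of the module
    # proper): load an ONNX model from disk, build the tensor table, and get
    # the DFG back. The model path, dtype, and weight directory are examples.
    import onnx

    onnx_model = onnx.load("model.onnx")
    builder = GraphBuilder(onnx_model, shape=None, dtype="float32",
                           weight_dir="weights")
    dfg = builder.build_graph()
    # Note: dfg.buildDFG() additionally requires the node graph to be
    # populated via add_to_graph before traversal can start.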