# graph_ir.py
################################################
# Top Level DFGNode interface
################################################


from typing import List
import onnx


class DFGNode:
    """Base class for dataflow-graph nodes built from a single ONNX node.

    The node copies the ONNX node's ``name``, ``op_type``, ``input`` and
    ``output`` fields. Subclasses override ``codegen`` / ``hpvm_codegen`` to
    return a ``(function_name, extra_args)`` pair; the defaults here return an
    empty name and no extra arguments.
    """

    def __init__(self, onnx_node: onnx.NodeProto):
        node = onnx_node
        self.name, self.op_type = node.name, node.op_type
        self.input, self.output = node.input, node.output

    def codegen(self):
        """Return ``("", [])``: no tensor-runtime call for the base node."""
        no_call = ("", [])
        return no_call

    def hpvm_codegen(self):
        """Return ``("", [])``: no HPVM intrinsic for the base node."""
        no_call = ("", [])
        return no_call

    def __repr__(self):
        cls_name = type(self).__name__
        return f"{cls_name}({self.input}) -> {self.output}"


################################################
# Actual Implementation of Operators
################################################


class AddNode(DFGNode):
    """Addition node, mapped to ``tensorAdd`` / ``__visc__tensor_add``."""

    def codegen(self):
        """Return the tensor-runtime function name and its (empty) extra args."""
        extra_args = []
        return ("tensorAdd", extra_args)

    def hpvm_codegen(self):
        """Return the HPVM intrinsic name and its (empty) extra args."""
        extra_args = []
        return ("__visc__tensor_add", extra_args)


class BiasAddNode(DFGNode):
    """Bias-add node derived from an ONNX conv node that has a bias input.

    The node keeps the conv node's name and outputs (via ``DFGNode``). Its
    inputs are rewritten to ``[conv_output_0, conv_input_2]``, i.e. the conv
    result followed by the conv node's third input (the bias).
    """

    def __init__(self, onnx_conv_node: onnx.NodeProto):
        super().__init__(onnx_conv_node)
        self.op_type = "BiasAdd"
        conv_result = self.output[0]
        bias = onnx_conv_node.input[2]
        self.input = [conv_result, bias]

    def codegen(self):
        """Return the tensor-runtime function name and its (empty) extra args."""
        extra_args = []
        return ("tensorAdd", extra_args)

    def hpvm_codegen(self):
        """Return the HPVM intrinsic name and its (empty) extra args."""
        extra_args = []
        return ("__visc__tensor_add", extra_args)


class MatMulNode(DFGNode):
    """Matrix-multiply node, mapped to ``tensorGemmGPU`` / ``__visc__tensor_mul``."""

    def codegen(self):
        """Return the tensor-runtime function name and its (empty) extra args."""
        extra_args = []
        return ("tensorGemmGPU", extra_args)

    def hpvm_codegen(self):
        """Return the HPVM intrinsic name and its (empty) extra args."""
        extra_args = []
        return ("__visc__tensor_mul", extra_args)


class SoftMaxNode(DFGNode):
    """Softmax node, mapped to ``tensorSoftmax`` / ``__visc__tensor_softmax``."""

    def codegen(self):
        """Return the tensor-runtime function name and its (empty) extra args."""
        extra_args = []
        return ("tensorSoftmax", extra_args)

    def hpvm_codegen(self):
        """Return the HPVM intrinsic name and its (empty) extra args."""
        extra_args = []
        return ("__visc__tensor_softmax", extra_args)


class Conv2DNode(DFGNode):
    """2-D convolution node built from an ONNX conv node.

    If the ONNX node has three inputs, the third one (the bias) is dropped
    here so the bias can be emitted as a separate bias-add node.

    Attributes:
        padding: height padding, taken from ``pads[0]`` (name kept for
            backward compatibility).
        padding_w: width padding, taken from ``pads[1]`` (falls back to
            ``pads[0]`` if only one value is given).
        strides: ``[stride_h, stride_w]``; defaults to ``[1, 1]`` when the
            ONNX ``strides`` attribute is absent, matching the ONNX default.
    """

    def __init__(self, onnx_node: onnx.NodeProto):
        super().__init__(onnx_node)
        if len(self.input) == 3:
            # Drop the bias input; it is handled by a separate bias-add node.
            self.input = list(self.input)[:2]

        pad_h = pad_w = 0
        strides = []
        for attr in onnx_node.attribute:
            if attr.name == "pads":
                # ONNX 2-D pads layout: [h_begin, w_begin, h_end, w_end].
                pads = list(attr.ints)
                if pads:
                    pad_h = pads[0]
                    pad_w = pads[1] if len(pads) > 1 else pads[0]
            elif attr.name == "strides":
                strides = list(attr.ints)

        self.padding = pad_h
        self.padding_w = pad_w
        # ONNX: strides default to 1 along each spatial axis when omitted.
        self.strides = strides or [1, 1]

    def _conv_args(self):
        """Return ``[pad_h, pad_w, stride_h, stride_w]`` for the runtime call.

        NOTE(review): assumes the runtime expects (vertical_pad,
        horizontal_pad, vertical_stride, horizontal_stride) in that order.
        """
        return [self.padding, self.padding_w, self.strides[0], self.strides[1]]

    def codegen(self):
        """Return ``("tensorConvolution", [pad_h, pad_w, stride_h, stride_w])``."""
        return "tensorConvolution", self._conv_args()

    def hpvm_codegen(self):
        """Return ``("__visc__tensor_convolution", [pad_h, pad_w, stride_h, stride_w])``."""
        return "__visc__tensor_convolution", self._conv_args()


class _Pool2DNode(DFGNode):
    """Shared attribute parsing and codegen for the 2-D pooling nodes.

    Subclasses set ``_POOL_TYPE`` (passed as the first argument of
    ``tensorPooling``) and ``_HPVM_FUNC`` (the HPVM intrinsic name).

    Attributes:
        pool_size: the ONNX ``kernel_shape`` values.
        padding: height padding from ``pads[0]`` (name kept for backward
            compatibility); 0 when ``pads`` is absent.
        padding_w: width padding from ``pads[1]``; 0 when ``pads`` is absent.
        strides: the ONNX ``strides`` values, defaulting to 1 per kernel axis
            when the attribute is absent (ONNX default).
    """

    _POOL_TYPE = ""
    _HPVM_FUNC = ""

    def __init__(self, onnx_node: onnx.NodeProto):
        super().__init__(onnx_node)
        self.pool_size = []
        self.strides = []
        pad_h = pad_w = 0
        for attr in onnx_node.attribute:
            if attr.name == "kernel_shape":
                self.pool_size = list(attr.ints)
            elif attr.name == "strides":
                self.strides = list(attr.ints)
            elif attr.name == "pads":
                # ONNX 2-D pads layout: [h_begin, w_begin, h_end, w_end].
                # Previously ignored, which silently emitted zero padding.
                pads = list(attr.ints)
                if pads:
                    pad_h = pads[0]
                    pad_w = pads[1] if len(pads) > 1 else pads[0]
        if not self.strides:
            # ONNX: strides default to 1 along each spatial axis when omitted.
            self.strides = [1] * len(self.pool_size)
        self.padding = pad_h
        self.padding_w = pad_w
        self.pool_type = self._POOL_TYPE

    def codegen(self):
        """Return ``("tensorPooling", [pool_type, *kernel, pad_h, pad_w, *strides])``."""
        return (
            "tensorPooling",
            [
                self.pool_type,
                *self.pool_size,
                self.padding,
                self.padding_w,
                *self.strides,
            ],
        )

    def hpvm_codegen(self):
        """Return ``(intrinsic, [*kernel, pad_h, pad_w, *strides])``."""
        return (
            self._HPVM_FUNC,
            [*self.pool_size, self.padding, self.padding_w, *self.strides],
        )


class MaxPool2DNode(_Pool2DNode):
    """2-D max-pooling node (``pool_type`` "0", ``__visc__tensor_pool_max``)."""

    _POOL_TYPE = "0"
    _HPVM_FUNC = "__visc__tensor_pool_max"


class AveragePool2DNode(_Pool2DNode):
    """2-D average-pooling node (``pool_type`` "1", ``__visc__tensor_pool_avg``)."""

    _POOL_TYPE = "1"
    _HPVM_FUNC = "__visc__tensor_pool_avg"


class ReluNode(DFGNode):
    """ReLU node, mapped to ``tensorRelu`` / ``__visc__tensor_relu``."""

    def codegen(self):
        """Return the tensor-runtime function name and its (empty) extra args."""
        extra_args = []
        return ("tensorRelu", extra_args)

    def hpvm_codegen(self):
        """Return the HPVM intrinsic name and its (empty) extra args."""
        extra_args = []
        return ("__visc__tensor_relu", extra_args)


class TanhNode(DFGNode):
    """Tanh node, mapped to ``tensorTanh`` / ``__visc__tensor_tanh``."""

    def codegen(self):
        """Return the tensor-runtime function name and its (empty) extra args."""
        extra_args = []
        return ("tensorTanh", extra_args)

    def hpvm_codegen(self):
        """Return the HPVM intrinsic name and its (empty) extra args."""
        extra_args = []
        return ("__visc__tensor_tanh", extra_args)


class BatchNormalizationNode(DFGNode):
    """Batch-normalization node built from an ONNX BatchNormalization node.

    ``epsilon`` is stored as a string and returned as-is as the single extra
    argument of ``codegen`` / ``hpvm_codegen``.
    """

    # ONNX BatchNormalization's documented default when `epsilon` is omitted.
    # Previously an empty string was emitted in that case.
    _DEFAULT_EPSILON = 1e-05

    def __init__(self, onnx_node: onnx.NodeProto):
        super().__init__(onnx_node)
        self.epsilon = str(self._DEFAULT_EPSILON)
        for attr in onnx_node.attribute:
            if attr.name == "epsilon":
                self.epsilon = str(attr.f)

    def codegen(self):
        """Return ``("tensorBatchNorm", [epsilon])``."""
        return "tensorBatchNorm", [self.epsilon]

    def hpvm_codegen(self):
        """Return ``("__visc__tensor_batchnorm", [epsilon])``."""
        return "__visc__tensor_batchnorm", [self.epsilon]


class FlattenNode(DFGNode):
    """Flatten node built from one ONNX node or from a multi-node ONNX idiom.

    Unlike other nodes, the fields are passed in directly rather than copied
    from a single ONNX node, so ``DFGNode.__init__`` is intentionally not
    called. It inherits ``DFGNode``'s default ``codegen`` / ``hpvm_codegen``.
    """

    def __init__(self, name: str, op_type: str, input, output):
        self.name = name
        self.op_type = op_type
        self.input = input
        self.output = output

    @classmethod
    def from_single_node(cls, n: onnx.NodeProto):
        """Build a FlattenNode that mirrors a single ONNX node's fields."""
        return cls(n.name, n.op_type, n.input, n.output)

    @classmethod
    def from_onnx_idiom(cls, nodes: List[onnx.NodeProto]):
        """Build one FlattenNode spanning a sequence of ONNX nodes.

        The node takes the first node's inputs and the last node's outputs.
        It is named ``Flatten_<suffix>``, where ``<suffix>`` is the part of
        the first node's name after its last underscore (or the whole name
        if it has none).
        """
        # rpartition tolerates names with zero or several underscores; the
        # previous two-way unpack of split("_") raised ValueError for those.
        suffix = nodes[0].name.rpartition("_")[2]
        return cls(f"Flatten_{suffix}", "Flatten", nodes[0].input, nodes[-1].output)


class ActivationNode(DFGNode):
    """Element-wise operator used as an activation function.

    Examples: HardSigmoid, LeakyRelu, PRelu, Pow, Reciprocal, Relu, Selu,
    Sigmoid, Softplus, Sqrt, ThresholdedRelu, Abs, Ceil, Elu, Floor, Neg.
    Inherits DFGNode's default (empty) codegen.
    """


class LogicalOpNode(DFGNode):
    """Element-wise logical or comparison operator (not an activation function).

    Examples: And, Equal, Greater, GreaterOrEqual, Less, LessOrEqual, Or, Xor.
    Inherits DFGNode's default (empty) codegen.
    """


class ZeroPadding2DNode(DFGNode):
    """Zero-padding node; inherits DFGNode's default (empty) codegen."""


class DepthwiseConv2DNode(DFGNode):
    """Depthwise 2-D convolution node; inherits DFGNode's default (empty) codegen."""


class DenseNode(DFGNode):
    """Dense node; inherits DFGNode's default (empty) codegen."""


class PadNode(DFGNode):
    """Pad node; inherits DFGNode's default (empty) codegen."""


class IdentityNode(DFGNode):
    """Identity node; inherits DFGNode's default (empty) codegen."""