Skip to content
Snippets Groups Projects
Commit b6606ceb authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Added support for GlobalAveragePool

parent c12a9b26
No related branches found
No related tags found
No related merge requests found
......@@ -183,7 +183,7 @@ class DFG(object):
# Split into conv followed by an addition
bias_node = g.BiasAddNode(f"Bias_{node.name.split('_')[-1]}")
return MarkedSubGraph.idiomatic_1to2(conv_node, bias_node, predec)
elif node.op_type in ("MatMul", "Gemm"):
if node.op_type in ("MatMul", "Gemm"):
mul_node = g.MatMulNode(node.name, **attrs)
if node.op_type == "Gemm":
mul_node.gemm_transpose(node, predec)
......@@ -192,6 +192,10 @@ class DFG(object):
# Split into mul followed by an addition
bias_node = g.BiasAddNode(f"Bias_{node.name.split('_')[-1]}")
return MarkedSubGraph.idiomatic_1to2(mul_node, bias_node, predec)
if node.op_type == "GlobalAveragePool":
input0_shape = in_graph.nodes[predec[0]]["shape"]
_, _, h, w = input0_shape
return g.AveragePool2DNode(node.name, [1, 1], (h, w), [0, 0, 0, 0])
one_to_one_nodes = {
"MaxPool": g.MaxPool2DNode,
"AveragePool": g.AveragePool2DNode,
......
import abc
from os import PathLike
from typing import List, Tuple
from typing import List, Sequence, Tuple
import onnx
......@@ -102,28 +102,36 @@ class WeightTensor(TensorNode):
class Conv2DNode(DFGNode):
op_type = "Conv2D"
def __init__(self, name: str, pads, strides, dilations, group: int, kernel_shape):
def __init__(
self,
name: str,
pads: Sequence[int],
strides: Sequence[int],
dilations: Sequence[int],
group: int,
kernel_shape,
):
super().__init__(name)
assert len(pads) == 4, "2D convolution must have 4 padding values"
if any(p != pads[0] for p in pads[1:]):
raise ValueError("Convolution with different padding is unsupported")
if dilations != [1, 1]:
if list(dilations) != [1, 1]:
raise ValueError("Dilation > 1 is unsupported")
if group != 1:
raise ValueError("Group > 1 is unsupported")
self.pads = pads[0]
self.strides = strides
self.strides = sh, sw = strides
def codegen(self):
return (
"tensorConvolution",
[self.pads, self.pads, self.strides[0], self.strides[1]],
[self.pads, self.pads, *self.strides],
)
def hpvm_codegen(self):
return (
"__visc__tensor_convolution",
[self.pads, self.pads, self.strides[0], self.strides[1]],
[self.pads, self.pads, *self.strides],
)
......@@ -132,13 +140,22 @@ class _Pool2DNode(DFGNode, abc.ABC):
pool_type = "0"
def __init__(self, name: str, strides, kernel_shape, pads, ceil_mode: int):
def __init__(
self,
name: str,
strides: Sequence[int],
kernel_shape: Sequence[int],
pads: Sequence[int],
ceil_mode: int = 0,
):
super().__init__(name)
self.strides = strides
self.strides = sh, sw = strides
self.kernel_shape = kernel_shape
pt, pb, pl, pr = pads
if pt != pb or pl != pr:
raise ValueError("Unequal padding on either side of same axis is unsupported")
raise ValueError(
"Unequal padding on either side of same axis is unsupported"
)
self.pads = pt, pl
if ceil_mode != 0:
raise ValueError("Only ceil_mode == 0 is supported")
......@@ -146,12 +163,7 @@ class _Pool2DNode(DFGNode, abc.ABC):
def codegen(self):
return (
"tensorPooling",
[
self.pool_type,
*self.kernel_shape,
*self.pads,
*self.strides,
],
[self.pool_type, *self.kernel_shape, *self.pads, *self.strides,],
)
def hpvm_codegen(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment