Skip to content
Snippets Groups Projects
Commit b6606ceb authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Added support for GlobalAveragePool

parent c12a9b26
No related branches found
No related tags found
No related merge requests found
...@@ -183,7 +183,7 @@ class DFG(object): ...@@ -183,7 +183,7 @@ class DFG(object):
# Split into conv followed by an addition # Split into conv followed by an addition
bias_node = g.BiasAddNode(f"Bias_{node.name.split('_')[-1]}") bias_node = g.BiasAddNode(f"Bias_{node.name.split('_')[-1]}")
return MarkedSubGraph.idiomatic_1to2(conv_node, bias_node, predec) return MarkedSubGraph.idiomatic_1to2(conv_node, bias_node, predec)
elif node.op_type in ("MatMul", "Gemm"): if node.op_type in ("MatMul", "Gemm"):
mul_node = g.MatMulNode(node.name, **attrs) mul_node = g.MatMulNode(node.name, **attrs)
if node.op_type == "Gemm": if node.op_type == "Gemm":
mul_node.gemm_transpose(node, predec) mul_node.gemm_transpose(node, predec)
...@@ -192,6 +192,10 @@ class DFG(object): ...@@ -192,6 +192,10 @@ class DFG(object):
# Split into mul followed by an addition # Split into mul followed by an addition
bias_node = g.BiasAddNode(f"Bias_{node.name.split('_')[-1]}") bias_node = g.BiasAddNode(f"Bias_{node.name.split('_')[-1]}")
return MarkedSubGraph.idiomatic_1to2(mul_node, bias_node, predec) return MarkedSubGraph.idiomatic_1to2(mul_node, bias_node, predec)
if node.op_type == "GlobalAveragePool":
input0_shape = in_graph.nodes[predec[0]]["shape"]
_, _, h, w = input0_shape
return g.AveragePool2DNode(node.name, [1, 1], (h, w), [0, 0, 0, 0])
one_to_one_nodes = { one_to_one_nodes = {
"MaxPool": g.MaxPool2DNode, "MaxPool": g.MaxPool2DNode,
"AveragePool": g.AveragePool2DNode, "AveragePool": g.AveragePool2DNode,
......
import abc import abc
from os import PathLike from os import PathLike
from typing import List, Tuple from typing import List, Sequence, Tuple
import onnx import onnx
...@@ -102,28 +102,36 @@ class WeightTensor(TensorNode): ...@@ -102,28 +102,36 @@ class WeightTensor(TensorNode):
class Conv2DNode(DFGNode): class Conv2DNode(DFGNode):
op_type = "Conv2D" op_type = "Conv2D"
def __init__(self, name: str, pads, strides, dilations, group: int, kernel_shape): def __init__(
self,
name: str,
pads: Sequence[int],
strides: Sequence[int],
dilations: Sequence[int],
group: int,
kernel_shape,
):
super().__init__(name) super().__init__(name)
assert len(pads) == 4, "2D convolution must have 4 padding values" assert len(pads) == 4, "2D convolution must have 4 padding values"
if any(p != pads[0] for p in pads[1:]): if any(p != pads[0] for p in pads[1:]):
raise ValueError("Convolution with different padding is unsupported") raise ValueError("Convolution with different padding is unsupported")
if dilations != [1, 1]: if list(dilations) != [1, 1]:
raise ValueError("Dilation > 1 is unsupported") raise ValueError("Dilation > 1 is unsupported")
if group != 1: if group != 1:
raise ValueError("Group > 1 is unsupported") raise ValueError("Group > 1 is unsupported")
self.pads = pads[0] self.pads = pads[0]
self.strides = strides self.strides = sh, sw = strides
def codegen(self): def codegen(self):
return ( return (
"tensorConvolution", "tensorConvolution",
[self.pads, self.pads, self.strides[0], self.strides[1]], [self.pads, self.pads, *self.strides],
) )
def hpvm_codegen(self): def hpvm_codegen(self):
return ( return (
"__visc__tensor_convolution", "__visc__tensor_convolution",
[self.pads, self.pads, self.strides[0], self.strides[1]], [self.pads, self.pads, *self.strides],
) )
...@@ -132,13 +140,22 @@ class _Pool2DNode(DFGNode, abc.ABC): ...@@ -132,13 +140,22 @@ class _Pool2DNode(DFGNode, abc.ABC):
pool_type = "0" pool_type = "0"
def __init__(self, name: str, strides, kernel_shape, pads, ceil_mode: int): def __init__(
self,
name: str,
strides: Sequence[int],
kernel_shape: Sequence[int],
pads: Sequence[int],
ceil_mode: int = 0,
):
super().__init__(name) super().__init__(name)
self.strides = strides self.strides = sh, sw = strides
self.kernel_shape = kernel_shape self.kernel_shape = kernel_shape
pt, pb, pl, pr = pads pt, pb, pl, pr = pads
if pt != pb or pl != pr: if pt != pb or pl != pr:
raise ValueError("Unequal padding on either side of same axis is unsupported") raise ValueError(
"Unequal padding on either side of same axis is unsupported"
)
self.pads = pt, pl self.pads = pt, pl
if ceil_mode != 0: if ceil_mode != 0:
raise ValueError("Only ceil_mode == 0 is supported") raise ValueError("Only ceil_mode == 0 is supported")
...@@ -146,12 +163,7 @@ class _Pool2DNode(DFGNode, abc.ABC): ...@@ -146,12 +163,7 @@ class _Pool2DNode(DFGNode, abc.ABC):
def codegen(self): def codegen(self):
return ( return (
"tensorPooling", "tensorPooling",
[ [self.pool_type, *self.kernel_shape, *self.pads, *self.strides,],
self.pool_type,
*self.kernel_shape,
*self.pads,
*self.strides,
],
) )
def hpvm_codegen(self): def hpvm_codegen(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment