Skip to content
Snippets Groups Projects
Commit 76e3aa8b authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Fixed max pooling / mean pooling confusion

parent e3cd6837
No related branches found
No related tags found
No related merge requests found
......@@ -198,12 +198,6 @@ class _Pool2DNode(DFGNode, abc.ABC):
[self.pool_type, *self.kernel_shape, *self.pads, *self.strides,],
)
def hpvm_codegen(self):
return (
"__hpvm__tensor_pool_max",
[*self.kernel_shape, *self.pads, *self.strides],
)
def get_flops(self) -> int:
input0 = self.input_shapes[0]
return np.prod(input0) if input0 else 0
......@@ -214,12 +208,24 @@ class MaxPool2DNode(_Pool2DNode):
op_type = "MaxPool2D"
hpvm_op_type = "maxpool"
def hpvm_codegen(self):
return (
"__hpvm__tensor_pool_max",
[*self.kernel_shape, *self.pads, *self.strides],
)
class AveragePool2DNode(_Pool2DNode):
pool_type = "1"
op_type = "AveragePool2D"
hpvm_op_type = "avgpool"
def hpvm_codegen(self):
return (
"__hpvm__tensor_pool_mean",
[*self.kernel_shape, *self.pads, *self.strides],
)
class BiasAddNode(DFGNode):
op_type = "BiasAdd"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment