From 76e3aa8bdb636b14cc96e51ec9d1084ef7dfcaef Mon Sep 17 00:00:00 2001 From: Yifan Zhao <yifanz16@illinois.edu> Date: Sat, 3 Apr 2021 13:17:38 -0500 Subject: [PATCH] Fixed max pooling / mean pooling confusion --- .../projects/torch2hpvm/torch2hpvm/graph_ir.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/hpvm/projects/torch2hpvm/torch2hpvm/graph_ir.py b/hpvm/projects/torch2hpvm/torch2hpvm/graph_ir.py index a088e6eae5..5c248f829a 100644 --- a/hpvm/projects/torch2hpvm/torch2hpvm/graph_ir.py +++ b/hpvm/projects/torch2hpvm/torch2hpvm/graph_ir.py @@ -198,12 +198,6 @@ class _Pool2DNode(DFGNode, abc.ABC): [self.pool_type, *self.kernel_shape, *self.pads, *self.strides,], ) - def hpvm_codegen(self): - return ( - "__hpvm__tensor_pool_max", - [*self.kernel_shape, *self.pads, *self.strides], - ) - def get_flops(self) -> int: input0 = self.input_shapes[0] return np.prod(input0) if input0 else 0 @@ -214,12 +208,24 @@ class MaxPool2DNode(_Pool2DNode): op_type = "MaxPool2D" hpvm_op_type = "maxpool" + def hpvm_codegen(self): + return ( + "__hpvm__tensor_pool_max", + [*self.kernel_shape, *self.pads, *self.strides], + ) + class AveragePool2DNode(_Pool2DNode): pool_type = "1" op_type = "AveragePool2D" hpvm_op_type = "avgpool" + def hpvm_codegen(self): + return ( + "__hpvm__tensor_pool_mean", + [*self.kernel_shape, *self.pads, *self.strides], + ) + class BiasAddNode(DFGNode): op_type = "BiasAdd" -- GitLab