diff --git a/hpvm/projects/torch2hpvm/torch2hpvm/graph_ir.py b/hpvm/projects/torch2hpvm/torch2hpvm/graph_ir.py index a088e6eae5c7cd8fb3db62f5046aa5d9ac945726..5c248f829adef15093b853891927f353aca30c4b 100644 --- a/hpvm/projects/torch2hpvm/torch2hpvm/graph_ir.py +++ b/hpvm/projects/torch2hpvm/torch2hpvm/graph_ir.py @@ -198,12 +198,6 @@ class _Pool2DNode(DFGNode, abc.ABC): [self.pool_type, *self.kernel_shape, *self.pads, *self.strides,], ) - def hpvm_codegen(self): - return ( - "__hpvm__tensor_pool_max", - [*self.kernel_shape, *self.pads, *self.strides], - ) - def get_flops(self) -> int: input0 = self.input_shapes[0] return np.prod(input0) if input0 else 0 @@ -214,12 +208,24 @@ class MaxPool2DNode(_Pool2DNode): op_type = "MaxPool2D" hpvm_op_type = "maxpool" + def hpvm_codegen(self): + return ( + "__hpvm__tensor_pool_max", + [*self.kernel_shape, *self.pads, *self.strides], + ) + class AveragePool2DNode(_Pool2DNode): pool_type = "1" op_type = "AveragePool2D" hpvm_op_type = "avgpool" + def hpvm_codegen(self): + return ( + "__hpvm__tensor_pool_mean", + [*self.kernel_shape, *self.pads, *self.strides], + ) + class BiasAddNode(DFGNode): op_type = "BiasAdd"