# code
import torch.nn as nn
# Input shape: (N, C, H, W)
class HorizontalMaxPool2d(nn.Module):
    """Max-pool each feature-map row across its full width.

    For an input of shape (N, C, H, W) the output has shape (N, C, H, 1):
    every row of every channel is reduced to its maximum value. An unbatched
    (C, H, W) input is also accepted and yields (C, H, 1).
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the per-row maximum of ``x`` along its last (width) axis.

        Args:
            x: tensor of shape (N, C, H, W) or (C, H, W).

        Returns:
            Tensor of shape (N, C, H, 1) or (C, H, 1).
        """
        # Take the width from the end so both batched (4-D) and unbatched
        # (3-D) inputs work; indexing size()[3] would fail on a 3-D tensor.
        width = x.size(-1)
        # A (1, width) window with the default stride (equal to the kernel
        # size) collapses each row to a single column.
        return nn.functional.max_pool2d(input=x, kernel_size=(1, width))
if __name__ == '__main__':
    import torch

    # Demo: pool a (32, 2048, 8, 4) batch and show the shape change.
    # torch.randn fills the tensor with random values; the legacy
    # torch.Tensor(*sizes) constructor would return uninitialized memory.
    # Note the difference between torch.tensor and torch.Tensor:
    # torch.tensor(data) builds a tensor by copying existing data (e.g. a list
    # or NumPy array), while torch.Tensor(*sizes) allocates a tensor by shape.
    x = torch.randn(32, 2048, 8, 4)
    print(x.size())  # torch.Size([32, 2048, 8, 4])
    hp = HorizontalMaxPool2d()
    y = hp(x)
    print(y.size())  # torch.Size([32, 2048, 8, 1])
# Expected output:
# torch.Size([32, 2048, 8, 4])
# torch.Size([32, 2048, 8, 1])