import torch
import torch.nn as nn
class Channel_Max_Pooling(nn.Module):
    """Max-pool a feature map along its channel dimension.

    The channel axis is swapped with the last (width) axis, a standard
    ``nn.MaxPool2d`` is applied over the trailing two axes (which are then
    height and channels), and the axes are swapped back. With
    ``kernel_size=(1, k)`` and ``stride=(1, k)`` every group of ``k``
    adjacent channels is reduced to its maximum while height and width are
    left unchanged.

    Args:
        kernel_size: Pooling window forwarded to ``nn.MaxPool2d``. After the
            axis swap its first entry applies to the height axis and its
            second entry to the channel axis.
        stride: Pooling stride forwarded to ``nn.MaxPool2d``, same axis order
            as ``kernel_size``.
        verbose: If True (the default, preserving the original behavior),
            ``forward`` prints the tensor shape at each stage.
    """

    def __init__(self, kernel_size, stride, *, verbose=True):
        super().__init__()
        self.verbose = verbose
        self.max_pooling = nn.MaxPool2d(
            kernel_size=kernel_size,
            stride=stride,
        )

    def _log(self, label, tensor):
        # Print a stage label and the tensor's shape when verbose is enabled.
        if self.verbose:
            print(label, tensor.shape)

    def forward(self, x):
        """Pool ``x`` over its channel axis.

        Args:
            x: Tensor laid out as ``(batch_size, chs, h, w)`` or, unbatched,
                ``(chs, h, w)``.

        Returns:
            Tensor in the same layout as ``x`` whose channel axis (and the
            height axis, if the first kernel/stride entry is not 1) has been
            max-pooled.
        """
        self._log('Input_Shape:', x)  # (batch_size, chs, h, w)
        # Negative dims swap channels and width for both batched 4-D and
        # unbatched 3-D input; fixed dims (1, 3) only worked for 4-D input.
        x = x.transpose(-3, -1)  # (batch_size, w, h, chs)
        self._log('Transpose_Shape:', x)
        # MaxPool2d pools over the last two axes, i.e. (h, chs) here.
        x = self.max_pooling(x)
        self._log('Transpose_MaxPooling_Shape:', x)
        # Swap back to the caller's (batch_size, chs', h, w) layout.
        out = x.transpose(-3, -1)
        self._log('Output_Shape:', out)
        return out
# Demo: pool a random (3, 6, 4, 4) batch over pairs of adjacent channels,
# giving a (3, 3, 4, 4) result, then print the input and the output.
cmp = Channel_Max_Pooling(kernel_size=(1, 2), stride=(1, 2))
tensor = torch.randn(3, 6, 4, 4)
out = cmp(tensor)
print(tensor, out, sep="\n")
# PyTorch implementation of ChannelPooling (channel pooling)
# (Page note: latest recommended article published 2023-07-14 16:43:35)