# Coordinate Attention: remedies the SE module's lack of positional information
# by global-average-pooling along the w and h directions separately, so the
# feature maps retain positional cues.
class JH(nn.Module):
    """Hard-swish activation: ``x * ReLU6(x + 3) / 6``.

    A cheap piecewise-linear approximation of swish (x * sigmoid(x)).
    """

    def __init__(self):
        super().__init__()
        # ReLU6 clamps to [0, 6]; inplace avoids an extra buffer on the temp.
        self.relu6 = nn.ReLU6(inplace=True)

    def forward(self, x):
        gate = self.relu6(x + 3) / 6
        return x * gate
class Coord(nn.Module):
def __init__(self,in_channel,out_channel,ratio):
super().__init__()
self.avg_w=nn.AdaptiveAvgPool2d((None,1))# c,h,1
self.avg_h=nn.AdaptiveAvgPool2d((1,None))# c,1,w
self.conv=nn.Conv2d(in_channels=in_channel,out_channels=in_channel//ratio,kernel_size=1)
self.bn=nn.BatchNorm2d(in_channel//ratio)
self.JH=JH()
self.conv1=nn.Conv2d(in_channels=in_channel//ratio,out_channels=out_channel,kernel_size=1)
self.conv2=nn.Conv2d(in_channels=in_channel//ratio,out_channels=out_channel,kernel_size=1)
self.sigmoid=nn.Sigmoid()
def forward(self,x):
identity=x
_,_,h,w=x.shape
# b,c,h,1
x_w=self.avg_w(x)
# b,c,1,w => b,c,w,1
x_h=self.avg_h(x).permute(0,1,3,2).contiguous()
# b,c,(h+w),1
x=torch.concat([x_w,x_h],dim=2)
x=self.conv(x)
x=self.bn(x)
x=self.JH(x)
x_w,x_h=torch.split(x,[h,w],dim=2)
x_w=self.sigmoid(self.conv1(x_w))
x_h=self.sigmoid(self.conv2(x_h)).permute(0,1,3,2).contiguous()
return identity*x_w*x_h
# Smoke test: push a random batch through the attention block; the output
# shape must equal the input shape since out_channel == in_channel.
if __name__ == "__main__":
    # Bug fix: torch.Tensor(64, 16, 224, 244) allocates UNINITIALIZED memory
    # (may contain NaN/garbage), making the smoke test nondeterministic.
    # torch.randn provides a well-defined random input instead.
    x = torch.randn(64, 16, 224, 244)
    func = Coord(16, 16, 4)
    a = func(x)
    print(a.shape)