CA注意力模块

在这里插入图片描述
弥补了SE模块缺少位置信息的缺点,通过对w,h方向进行全局平均池化,使得特征层能够获取到位置信息

class JH(nn.Module):
    def __init__(self):
        super().__init__()
        self.relu6=nn.ReLU6(inplace=True)
    def forward(self,x):
        return x*(self.relu6(x+3)/6)



class Coord(nn.Module):
    def __init__(self,in_channel,out_channel,ratio):
        super().__init__()
        self.avg_w=nn.AdaptiveAvgPool2d((None,1))# c,h,1
        self.avg_h=nn.AdaptiveAvgPool2d((1,None))# c,1,w
        self.conv=nn.Conv2d(in_channels=in_channel,out_channels=in_channel//ratio,kernel_size=1)
        self.bn=nn.BatchNorm2d(in_channel//ratio)
        self.JH=JH()
        self.conv1=nn.Conv2d(in_channels=in_channel//ratio,out_channels=out_channel,kernel_size=1)
        self.conv2=nn.Conv2d(in_channels=in_channel//ratio,out_channels=out_channel,kernel_size=1)
        self.sigmoid=nn.Sigmoid()
    def forward(self,x):
        identity=x
        _,_,h,w=x.shape
        # b,c,h,1
        x_w=self.avg_w(x)
        # b,c,1,w  => b,c,w,1
        x_h=self.avg_h(x).permute(0,1,3,2).contiguous()
        # b,c,(h+w),1
        x=torch.concat([x_w,x_h],dim=2)
        x=self.conv(x)
        x=self.bn(x)
        x=self.JH(x)
        x_w,x_h=torch.split(x,[h,w],dim=2)
        x_w=self.sigmoid(self.conv1(x_w))
        x_h=self.sigmoid(self.conv2(x_h)).permute(0,1,3,2).contiguous()
        return identity*x_w*x_h
        
x=torch.Tensor(64,16,224,244)
func=Coord(16,16,4)
a=func(x)
print(a.shape)
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值