mobiefacenet添加注意力机制

https://github.com/Sunshine066/my-face-althorithm

完整代码已经放到github,包括添加现存各个注意力机制的mobiefacenet,欢迎查看指正。有问题沟通
下面只放部分注意力机制的代码。

class SqueezeExcitation(nn.Module):
    # squeeze_factor: int = 4:第一个FC层节点个数是输入特征矩阵的1/4
    def __init__(self, input_c: int, squeeze_factor: int = 4):
        super(SqueezeExcitation, self).__init__()
        # 第一个FC层节点个数,也要是8的整数倍
        squeeze_c = _make_divisible(input_c // squeeze_factor, 8)
        # print("yttest     %d" %squeeze_c)
        # 通过卷积核大小为1x1的卷积替代FC层,作用相同
        self.fc1 = nn.Conv2d(input_c, squeeze_c, 1)
        self.fc2 = nn.Conv2d(squeeze_c, input_c, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x有很多channel,通过output_size=(1, 1)实现每个channel变成1个数字
        scale = nn.functional.adaptive_avg_pool2d(x, output_size=(1, 1))
        scale = self.fc1(scale)
        scale = nn.functional.relu(scale, inplace=True)
        scale = self.fc2(scale)
        # 此处的scale就是第二个FC层输出的数据
        scale = nn.functional.hardsigmoid(scale, inplace=True)
        return scale * x        # 和原输入相乘,得到SE模块的输出

yt add start end

yt add start,3Dse moudle,空间注意力模块

class conv_block(nn.Module):
    def __init__(self,ch_in,ch_out):
        super(conv_block,self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self,x):
        x = self.conv(x)
        return x


class SqueezeAttentionBlock(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(SqueezeAttentionBlock, self).__init__()
        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv = conv_block(ch_in, ch_out)
        self.conv_atten = conv_block(ch_in, ch_out)
        self.upsample = nn.Upsample(scale_factor=2)

    def forward(self, x):
        print(x.shape)
        x_res = self.conv(x)
        print(x_res.shape)
        y = self.avg_pool(x)
        print(y.shape)
        y = self.conv_atten(y)
        print(y.shape)
        y = self.upsample(y)
        print(y.shape[0])
        # yt add, 解决奇数大小特征图下采样、上采样后无法恢复原始特征图大小的问题
        if y.shape != x_res.shape:
            y = torch.nn.functional.interpolate(y, size=[y.shape[2]+1,y.shape[3]+1], scale_factor=None, mode='nearest', align_corners=None,
                                            recompute_scale_factor=None)
        print(y.shape, x_res.shape)
        return (y * x_res) + y
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值