【即插即用】SGE注意力机制(附源码)

原文链接:

https://arxiv.org/abs/1905.09646

源码链接:

https://github.com/implus/PytorchInsight

摘要简介:

在图像识别领域,卷积神经网络(CNN)通过收集和整合复杂对象的层次化和不同部分的语义子特征来生成特征表示。这些子特征通常以分组的形式分布在每一层的特征向量中,代表不同的语义实体。然而,这些子特征的激活常常受到相似模式和噪声背景的空间影响,从而可能导致定位和识别的错误。

为了解决这一问题,研究者们提出了一种名为空间分组增强(SGE)的模块。SGE模块可以为每个语义组中的每个空间位置生成一个注意力因子,以调整每个子特征的重要性。通过这种方式,每个单独的组都能够自主地增强其学习到的表达,并抑制可能的噪声。这些注意力因子仅由组内全局和局部特征描述符之间的相似性来指导,因此SGE模块的设计非常轻量级,几乎不需要额外的参数和计算。

尽管SGE组件仅通过类别监督进行训练,但它在突出显示具有各种高阶语义的多个活跃区域方面表现出色,如狗的眼睛、鼻子等。当与流行的CNN骨干网络结合使用时,SGE能够显著提高图像识别任务的性能。具体来说,基于ResNet50骨干网络,SGE在ImageNet基准测试中实现了1.2%的Top-1准确率提升,并在广泛的检测器(Faster/Mask/Cascade RCNN和RetinaNet)上,于COCO基准测试中获得了1.0∼2.0%的AP增益。相关的代码和预训练模型已经公开可用。


模型结构:

Pytorch版源码:

# Spatial Group-wise Enhance主要是用在语义分割上,所以在检测上的效果一般,没有带来多少提升
import torch
from torch import nn
from torch.nn import init


class SpatialGroupEnhance(nn.Module):

    def __init__(self, groups):
        super().__init__()
        self.groups = groups
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.weight = nn.Parameter(torch.zeros(1, groups, 1, 1))
        self.bias = nn.Parameter(torch.zeros(1, groups, 1, 1))
        self.sig = nn.Sigmoid()
        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        b, c, h, w = x.shape
        x = x.view(b * self.groups, -1, h, w)  # bs*g,dim//g,h,w
        xn = x * self.avg_pool(x)  # bs*g,dim//g,h,w
        xn = xn.sum(dim=1, keepdim=True)  # bs*g,1,h,w
        t = xn.view(b * self.groups, -1)  # bs*g,h*w

        t = t - t.mean(dim=1, keepdim=True)  # bs*g,h*w
        std = t.std(dim=1, keepdim=True) + 1e-5
        t = t / std  # bs*g,h*w
        t = t.view(b, self.groups, h, w)  # bs,g,h*w

        t = t * self.weight + self.bias  # bs,g,h*w
        t = t.view(b * self.groups, 1, h, w)  # bs*g,1,h*w
        x = x * self.sig(t)
        x = x.view(b, c, h, w)
        return x

if __name__ == '__main__':
    input = torch.randn(2, 32, 512, 512)
    SGE = SpatialGroupEnhance(groups=input.size(1))
    output = SGE(input)
    print(output.shape)

  • 12
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值