【即插即用】ShuffleAttention注意力机制(附源码)

原文链接:

https://arxiv.org/pdf/2102.00240.pdf

源码地址:

https://github.com/wofmanaf/SA-Ne

摘要简介:

注意力机制让神经网络能够准确关注输入的所有相关元素,已成为提高深度神经网络性能的关键组件。在计算机视觉研究中,主要有两种广泛使用的注意力机制:空间注意力和通道注意力。它们分别旨在捕捉像素级的成对关系和通道依赖性。虽然将它们融合在一起可能比单独使用它们表现更好,但这会不可避免地增加计算开销。

在本文中,我们提出了一个高效的Shuffle Attention(SA)模块来解决这个问题。该模块采用Shuffle单元有效地结合了两种注意力机制。具体来说,SA首先将通道维度分组为多个子特征,然后并行处理它们。接着,对于每个子特征,SA使用Shuffle单元来描绘空间和通道维度上的特征依赖性。之后,所有子特征被聚合,并采用“通道Shuffle”操作符来使不同子特征之间的信息得以交流。

提出的SA模块既高效又有效。例如,与骨干网络ResNet50相比,SA的参数和计算量分别为300与25.56M,以及2.76e-3 GFLOPs与4.12 GFLOPs,但Top-1准确率提高了1.34%以上。在ImageNet-1k分类、MS COCO目标检测和实例分割等常用基准测试上的大量实验结果表明,所提出的SA在保持较低模型复杂度的同时,显著优于当前的SOTA方法,实现了更高的精度。

模型结构图:

Pytorch版源码:
import torch
from torch import nn
from torch.nn import init
from torch.nn.parameter import Parameter


class ShuffleAttention(nn.Module):

    def __init__(self, channel=512, G=8):
        super().__init__()
        self.G = G
        self.channel = channel
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.gn = nn.GroupNorm(channel // (2 * G), channel // (2 * G))
        self.cweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
        self.cbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
        self.sweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
        self.sbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
        self.sigmoid = nn.Sigmoid()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    @staticmethod
    def channel_shuffle(x, groups):
        b, c, h, w = x.shape
        x = x.reshape(b, groups, -1, h, w)
        x = x.permute(0, 2, 1, 3, 4)

        # 扁平化
        x = x.reshape(b, -1, h, w)

        return x

    def forward(self, x):
        b, c, h, w = x.size()
        # 将通道分成子特征
        x = x.view(b * self.G, -1, h, w)  # bs*G,c//G,h,w

        # 通道分割
        x_0, x_1 = x.chunk(2, dim=1)  # bs*G,c//(2*G),h,w

        # 通道注意力
        x_channel = self.avg_pool(x_0)  # bs*G,c//(2*G),1,1
        x_channel = self.cweight * x_channel + self.cbias  # bs*G,c//(2*G),1,1
        x_channel = x_0 * self.sigmoid(x_channel)

        # 空间注意力
        x_spatial = self.gn(x_1)  # bs*G,c//(2*G),h,w
        x_spatial = self.sweight * x_spatial + self.sbias  # bs*G,c//(2*G),h,w
        x_spatial = x_1 * self.sigmoid(x_spatial)  # bs*G,c//(2*G),h,w

        # 沿通道轴拼接
        out = torch.cat([x_channel, x_spatial], dim=1)  # bs*G,c//G,h,w
        out = out.contiguous().view(b, -1, h, w)

        # 通道混洗
        out = self.channel_shuffle(out, 2)
        return out


if __name__ == '__main__':
    input = torch.randn(2, 32, 512, 512)
    SA = ShuffleAttention(channel=input.size(1))
    output = SA(input)
    print(output.shape)

  • 12
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
L-注意力机制是一种结合了长短期记忆网络(LSTM)和注意力机制神经网络模型。引用和中提到了一些基于LSTM和注意力机制的时间序列预测的实现源码和数据。 LSTM是一种递归神经网络,被广泛应用于序列数据的建模和预测。它通过门控单元的设计,能够有效地捕捉序列中的长期依赖关系。而注意力机制则是一种机制,可以使模型自动地关注输入序列中的重要部分。它通过给予不同输入部分不同的权重,使模型能够更加集中地处理关键信息。 LSTM-注意力机制结合了LSTM和注意力机制的优点,能够在处理时间序列数据时更好地捕捉序列中的重要信息,提高预测准确性。这种模型在诸如文本翻译、语音识别和股票预测等任务中得到了广泛的应用。 引用中提到了神经机器翻译(NMT)作为LSTM-注意力机制的一个应用示例。在NMT中,LSTM-注意力机制被用来将源语言句子映射成一个固定长度的向量表示,并基于该向量生成目标语言的翻译。通过引入注意力机制,NMT能够更好地处理长句子和复杂语言结构,提高翻译质量。 最后,引用中提到了注意力机制深度学习的最新趋势之一。注意力机制的引入使得神经网络能够更加灵活地处理输入序列中的不同部分,提高了模型的表现和效果。 综上所述,LSTM-注意力机制是一种结合了长短期记忆网络和注意力机制神经网络模型,用于处理时间序列数据和任务,如文本翻译、语音识别和股票预测等。它能够更好地捕捉序列中的重要信息,提高预测准确性,并在深度学习领域具有广泛的应用前景。<span class="em">1</span><span class="em">2</span><span class="em">3</span><span class="em">4</span>

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值