Shuffle Attention注意力机制详解

一、流程梳理

ShuffleAttention 是一种结合了 通道注意力(Channel Attention) 和 空间注意力(Spatial Attention) 的注意力机制。通过分组、计算注意力,并将其应用于输入特征图,从而提高特征表示的能力和网络的性能。其独特之处在于引入了 通道重排(Channel Shuffle) 技术,旨在提升模型的多样性和表示能力。

1、分组操作(Group Splitting)

       通过将输入特征图的通道维度分成多个小组(由参数 G 确定),并进行分组处理。这样做的目的是减少计算量,同时也能够提升网络的表达能力。

2、通道注意力(Channel Attention)

       使用全局平均池化 (avg_pool) 操作得到每组特征图的全局信息,并通过学习到的权重和偏置调整每组特征的通道重要性。最后通过 Sigmoid 激活函数将得到的通道注意力应用于对应的特征图。

3、空间注意力(Spatial Attention)

       使用组归一化 (GroupNorm) 计算每组特征的空间注意力。空间注意力反映了在空间维度上,哪些区域更加重要。经过 Sigmoid 激活后,得到的空间注意力应用于特征图的空间维度。

4、通道重排(Channel Shuffle)

       在通道注意力和空间注意力的计算之后,使用通道重排操作将通道顺序打乱,以增强模型对通道特征的表达能力。通道重排通过重新排列通道的顺序,有助于引入多样性和对不同特征的联合建模。

二、实现细节

2.1 SA模块概述

       采用通道分割来并行处理每个组的子特征。
对于通道注意力分支,使用全局平均池化(GAP)生成通道级统计数据,然后使用一对参数来缩放和偏移通道向量;对于空间注意力分支,采用组归一化来生成空间级统计数据,然后创建一个类似于通道分支的紧凑特征。然后将两个分支连接起来。之后,所有子特征被聚合,最后使用一个通道洗牌操作符来实现不同子特征之间的信息交互。

2.2 具体实现

2.2.1. 通道分组:将输入特征图 x 的通道数通过 G 进行分组,得到多个子特征图。这个过程通过 x.view(b * G, -1, h, w) 实现,将通道数按 G 分组进行重塑。

2.2.2. 通道拆分:将每个分组的特征图拆成两部分(x_0 和 x_1),分别计算通道注意力和空间注意力。

2.2.3. 通道注意力:

  • 对 x_0 使用全局平均池化操作得到每个小组的全局信息。

  • 对得到的全局信息进行线性变换(通过学习的权重和偏置)来计算每个通道的重要性。

  • 使用 Sigmoid 激活函数得到通道注意力,并应用到 x_0 上,调整通道特征的权重。

2.2.4. 空间注意力:

  • 对 x_1 使用组归一化计算空间注意力。

  • 使用学习到的权重和偏置进行调整,再通过 Sigmoid 激活函数生成空间注意力。

  • 将空间注意力应用到 x_1 上,调整空间维度的特征。

2.2.5. 拼接操作:将经过通道注意力和空间注意力加权的 x_0 和 x_1 拼接在一起,形成一个新的特征图。

2.2.6. 通道重排:通过 channel_shuffle 操作将拼接后特征图的通道进行重排,从而增加特征的多样性和表达能力。该操作通过将通道分组并交换位置来实现。

2.2.7. 输出:返回重排后的输出特征图。

def shuffleAttention(x,cw,cb,sw,sb,6):
    # x:input features with shape [N,c,H,W]
    #cw,cb,sw,sb:parameters with shape [1,C//2G,1,1]# G: number of groups
    N,_,H,W=x
    # group into subfeatures
    x=x.reshape(NG,-1,H,W)
    # channel split
    x0,x1=x.chunk(2,dim=1)
    # channel attentionxn = avg_pool(x0)xn=cw*xn +cbxn=x0*sigmoid(xn)
    # spatial attentionxs =GroupNorm(x1)Xs = SW * Xs + sbxs =x1*sigmoid(xs)
    # concatenate and aggragate
    out =torch.cat([xn,xs],dim=1)
    out =out.shape(N,-1,H,W)
    # channel shuffleout = channel shuffle(out,2)return out

三、代码实现

import numpy as np
import torch
from torch import nn
from torch.nn import init
from torch.nn.parameter import Parameter


class ShuffleAttention(nn.Module):

    def __init__(self, channel=512,reduction=16,G=8):

        """        
        初始化 Shuffle Attention 模块        
        :param channel: 输入特征图的通道数        
        :param reduction: 用于通道注意力的缩放因子        
        :param G: 分组数,表示将通道数分成 G 组进行处理        
        """
        super().__init__()
        self.G=G
        self.channel=channel

        # 全局平均池化,用于通道注意力的计算
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        # 对输入特征进行分组的归一化,降低计算量
        self.gn = nn.GroupNorm(channel // (2 * G), channel // (2 * G))

        # 定义通道注意力和空间注意力的权重和偏置参数
        self.cweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
        self.cbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
        self.sweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
        self.sbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))

        self.sigmoid=nn.Sigmoid()


    def init_weights(self):
        # 初始化模块的权重
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # 使用Kaiming正态分布初始化卷积层权重
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    # 将卷积层的偏置初始化为零
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                # 批归一化层的权重初始化为1
                init.constant_(m.weight, 1)
                # 批归一化层的偏置初始化为0
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                # 线性层的权重初始化为正态分布
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    # 线性层的偏置初始化为零
                    init.constant_(m.bias, 0)


    @staticmethod
    def channel_shuffle(x, groups):
        # :param x: 输入张量,形状为 (b, c, h, w)        
        # :param groups: 分组数        
        # :return: 重排后的张量
        b, c, h, w = x.shape
        x = x.reshape(b, groups, -1, h, w)    # 按照 groups 将通道维度进行重塑
        x = x.permute(0, 2, 1, 3, 4)          # 调整维度顺序,使通道组进行交换

        # flatten
        x = x.reshape(b, -1, h, w)

        return x

    def forward(self, x):
        # :param x: 输入张量,形状为 (batch_size, channel, height, width)        
        # :return: 输出张量
        b, c, h, w = x.size()

        # 将输入按照 G 进行分组
        x=x.view(b*self.G,-1,h,w) #bs*G,c//G,h,w

        #channel_split
        x_0,x_1=x.chunk(2,dim=1) #bs*G,c//(2*G),h,w

        #channel attention
        x_channel=self.avg_pool(x_0) #bs*G,c//(2*G),1,1
        # 通过权重和偏置进行调整
        x_channel=self.cweight*x_channel+self.cbias #bs*G,c//(2*G),1,1
        # 应用sigmoid激活函数得到通道注意力
        x_channel=x_0*self.sigmoid(x_channel)

        #spatial attention
        x_spatial=self.gn(x_1) #bs*G,c//(2*G),h,w
        x_spatial=self.sweight*x_spatial+self.sbias #bs*G,c//(2*G),h,w
        x_spatial=x_1*self.sigmoid(x_spatial) #bs*G,c//(2*G),h,w

        # concatenate along channel axis
        # 将通道注意力和空间注意力的结果在通道维度上拼接
        out=torch.cat([x_channel,x_spatial],dim=1)  #bs*G,c//G,h,w
        out=out.contiguous().view(b,-1,h,w)

        # channel shuffle
        out = self.channel_shuffle(out, 2)
        return out

# 模块测试
if __name__ == '__main__':
    
    # 创建随机张量
    input=torch.randn(50,512,7,7)

    # 创建 ShuffleAttention 实例
    se = ShuffleAttention(channel=512,G=8)

    output=se(input)
    print(output.shape)

论文题目Efficient Attention: Attention with Linear Complexities

论文链接:https://arxiv.org/pdf/2102.00240

官方github:https://github.com/wofmanaf/SA-Net

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

ghx3110

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值