一、流程梳理
ShuffleAttention 是一种结合了 通道注意力(Channel Attention) 和 空间注意力(Spatial Attention) 的注意力机制。通过分组、计算注意力,并将其应用于输入特征图,从而提高特征表示的能力和网络的性能。其独特之处在于引入了 通道重排(Channel Shuffle) 技术,旨在提升模型的多样性和表示能力。
1、分组操作(Group Splitting)
通过将输入特征图的通道维度分成多个小组(由参数 G 确定),并进行分组处理。这样做的目的是减少计算量,同时也能够提升网络的表达能力。
2、通道注意力(Channel Attention)
使用全局平均池化 (avg_pool) 操作得到每组特征图的全局信息,并通过学习到的权重和偏置调整每组特征的通道重要性。最后通过 Sigmoid 激活函数将得到的通道注意力应用于对应的特征图。
3、空间注意力(Spatial Attention)
使用组归一化 (GroupNorm) 计算每组特征的空间注意力。空间注意力反映了在空间维度上,哪些区域更加重要。经过 Sigmoid 激活后,得到的空间注意力应用于特征图的空间维度。
4、通道重排(Channel Shuffle)
在通道注意力和空间注意力的计算之后,使用通道重排操作将通道顺序打乱,以增强模型对通道特征的表达能力。通道重排通过重新排列通道的顺序,有助于引入多样性和对不同特征的联合建模。
二、实现细节
2.1 SA模块概述
采用通道分割来并行处理每个组的子特征。 对于通道注意力分支,使用全局平均池化(GAP)生成通道级统计数据,然后使用一对参数来缩放和偏移通道向量;对于空间注意力分支,采用组归一化来生成空间级统计数据,然后创建一个类似于通道分支的紧凑特征。然后将两个分支连接起来。之后,所有子特征被聚合,最后使用一个通道洗牌操作符来实现不同子特征之间的信息交互。
2.2 具体实现
2.2.1. 通道分组:将输入特征图 x 的通道数通过 G 进行分组,得到多个子特征图。这个过程通过 x.view(b * G, -1, h, w) 实现,将通道数按 G 分组进行重塑。
2.2.2. 通道拆分:将每个分组的特征图拆成两部分(x_0 和 x_1),分别计算通道注意力和空间注意力。
2.2.3. 通道注意力:
-
对 x_0 使用全局平均池化操作得到每个小组的全局信息。
-
对得到的全局信息进行线性变换(通过学习的权重和偏置)来计算每个通道的重要性。
-
使用 Sigmoid 激活函数得到通道注意力,并应用到 x_0 上,调整通道特征的权重。
2.2.4. 空间注意力:
-
对 x_1 使用组归一化计算空间注意力。
-
使用学习到的权重和偏置进行调整,再通过 Sigmoid 激活函数生成空间注意力。
-
将空间注意力应用到 x_1 上,调整空间维度的特征。
2.2.5. 拼接操作:将经过通道注意力和空间注意力加权的 x_0 和 x_1 拼接在一起,形成一个新的特征图。
2.2.6. 通道重排:通过 channel_shuffle 操作将拼接后特征图的通道进行重排,从而增加特征的多样性和表达能力。该操作通过将通道分组并交换位置来实现。
2.2.7. 输出:返回重排后的输出特征图。
def shuffleAttention(x,cw,cb,sw,sb,6):
# x:input features with shape [N,c,H,W]
#cw,cb,sw,sb:parameters with shape [1,C//2G,1,1]# G: number of groups
N,_,H,W=x
# group into subfeatures
x=x.reshape(NG,-1,H,W)
# channel split
x0,x1=x.chunk(2,dim=1)
# channel attentionxn = avg_pool(x0)xn=cw*xn +cbxn=x0*sigmoid(xn)
# spatial attentionxs =GroupNorm(x1)Xs = SW * Xs + sbxs =x1*sigmoid(xs)
# concatenate and aggragate
out =torch.cat([xn,xs],dim=1)
out =out.shape(N,-1,H,W)
# channel shuffleout = channel shuffle(out,2)return out
三、代码实现
import numpy as np
import torch
from torch import nn
from torch.nn import init
from torch.nn.parameter import Parameter
class ShuffleAttention(nn.Module):
def __init__(self, channel=512,reduction=16,G=8):
"""
初始化 Shuffle Attention 模块
:param channel: 输入特征图的通道数
:param reduction: 用于通道注意力的缩放因子
:param G: 分组数,表示将通道数分成 G 组进行处理
"""
super().__init__()
self.G=G
self.channel=channel
# 全局平均池化,用于通道注意力的计算
self.avg_pool = nn.AdaptiveAvgPool2d(1)
# 对输入特征进行分组的归一化,降低计算量
self.gn = nn.GroupNorm(channel // (2 * G), channel // (2 * G))
# 定义通道注意力和空间注意力的权重和偏置参数
self.cweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
self.cbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
self.sweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
self.sbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
self.sigmoid=nn.Sigmoid()
def init_weights(self):
# 初始化模块的权重
for m in self.modules():
if isinstance(m, nn.Conv2d):
# 使用Kaiming正态分布初始化卷积层权重
init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
# 将卷积层的偏置初始化为零
init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
# 批归一化层的权重初始化为1
init.constant_(m.weight, 1)
# 批归一化层的偏置初始化为0
init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
# 线性层的权重初始化为正态分布
init.normal_(m.weight, std=0.001)
if m.bias is not None:
# 线性层的偏置初始化为零
init.constant_(m.bias, 0)
@staticmethod
def channel_shuffle(x, groups):
# :param x: 输入张量,形状为 (b, c, h, w)
# :param groups: 分组数
# :return: 重排后的张量
b, c, h, w = x.shape
x = x.reshape(b, groups, -1, h, w) # 按照 groups 将通道维度进行重塑
x = x.permute(0, 2, 1, 3, 4) # 调整维度顺序,使通道组进行交换
# flatten
x = x.reshape(b, -1, h, w)
return x
def forward(self, x):
# :param x: 输入张量,形状为 (batch_size, channel, height, width)
# :return: 输出张量
b, c, h, w = x.size()
# 将输入按照 G 进行分组
x=x.view(b*self.G,-1,h,w) #bs*G,c//G,h,w
#channel_split
x_0,x_1=x.chunk(2,dim=1) #bs*G,c//(2*G),h,w
#channel attention
x_channel=self.avg_pool(x_0) #bs*G,c//(2*G),1,1
# 通过权重和偏置进行调整
x_channel=self.cweight*x_channel+self.cbias #bs*G,c//(2*G),1,1
# 应用sigmoid激活函数得到通道注意力
x_channel=x_0*self.sigmoid(x_channel)
#spatial attention
x_spatial=self.gn(x_1) #bs*G,c//(2*G),h,w
x_spatial=self.sweight*x_spatial+self.sbias #bs*G,c//(2*G),h,w
x_spatial=x_1*self.sigmoid(x_spatial) #bs*G,c//(2*G),h,w
# concatenate along channel axis
# 将通道注意力和空间注意力的结果在通道维度上拼接
out=torch.cat([x_channel,x_spatial],dim=1) #bs*G,c//G,h,w
out=out.contiguous().view(b,-1,h,w)
# channel shuffle
out = self.channel_shuffle(out, 2)
return out
# 模块测试
if __name__ == '__main__':
# 创建随机张量
input=torch.randn(50,512,7,7)
# 创建 ShuffleAttention 实例
se = ShuffleAttention(channel=512,G=8)
output=se(input)
print(output.shape)
论文题目:Efficient Attention: Attention with Linear Complexities
论文链接:https://arxiv.org/pdf/2102.00240
官方github:https://github.com/wofmanaf/SA-Net