论文地址:https://arxiv.org/abs/2112.05561#/
摘要:
为了提升各种计算机视觉任务的性能,已经研究了多种注意力机制。然而,先前的方法忽视了在通道和空间两个层面保留信息的重要性,从而影响了跨维度的交互。因此,我们提出了一种全局注意力机制,它通过减少信息损失并放大全局交互表示,来提升深度神经网络的性能。我们为通道注意力引入了3D置换和多层感知器,并设计了一个卷积空间注意力子模块。在CIFAR-100和ImageNet-1K图像分类任务上评估我们提出的机制时,结果显示,无论是使用ResNet还是轻量级的MobileNet,我们的方法都稳定地优于最近的几种注意力机制。
简而言之,我们设计了一种新的全局注意力机制,它能更好地保留和处理图像在通道和空间两个维度的信息,进而提升深度神经网络在图像分类任务上的性能。在多个数据集和模型上的测试都显示,我们的方法比其他注意力机制更为出色。
结构图:
Pytorch源码:
import torch.nn as nn
import torch
class GAM_Attention(nn.Module):
def __init__(self, in_channels, out_channels, rate=4):
super(GAM_Attention, self).__init__()
self.channel_attention = nn.Sequential(
nn.Linear(in_channels, int(in_channels / rate)),
nn.ReLU(inplace=True),
nn.Linear(int(in_channels / rate), in_channels)
)
self.spatial_attention = nn.Sequential(
nn.Conv2d(in_channels, int(in_channels / rate), kernel_size=7, padding=3),
nn.BatchNorm2d(int(in_channels / rate)),
nn.ReLU(inplace=True),
nn.Conv2d(int(in_channels / rate), out_channels, kernel_size=7, padding=3),
nn.BatchNorm2d(out_channels)
)
def forward(self, x):
b, c, h, w = x.shape
x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
x_channel_att = x_att_permute.permute(0, 3, 1, 2)
x = x * x_channel_att
x_spatial_att = self.spatial_attention(x).sigmoid()
out = x * x_spatial_att
return out
if __name__ == '__main__':
x = torch.randn(1, 64, 32, 32)
b, c, h, w = x.shape
net = GAM_Attention(in_channels=c, out_channels=c)
y = net(x)
print(y.shape)