paper: https://arxiv.org/pdf/2105.02358v2
code: https://paperswithcode.com/paper/beyond-self-attention-external-attention
简介:
摘要:
注意力机制，尤其是自注意力（self-attention），在视觉任务的深度特征表示中发挥着越来越重要的作用。自注意力通过计算所有位置之间的成对相似度，对特征进行加权求和来更新每个位置的特征，从而捕获单个样本内部的长距离依赖。然而，自注意力具有二次复杂度，并且忽略了不同样本之间的潜在关联。本文提出了一种新的注意力机制，称为外部注意力（external attention）。它基于两个外部的、小型的、可学习的、共享的记忆单元，只需两个级联的线性层和两个归一化层即可轻松实现，并且可以方便地替代现有流行架构中的自注意力。外部注意力具有线性复杂度，并隐式地考虑了所有数据样本之间的关联。我们进一步将多头机制引入外部注意力，构建了一种用于图像分类的全 MLP 架构——外部注意力 MLP（EAMLP）。在图像分类、目标检测、语义分割、实例分割、图像生成和点云分析上的大量实验表明，与自注意力机制及其一些变体相比，我们的方法能取得相当或更好的结果，同时计算和内存开销要低得多。
结论:
本文介绍了外部注意力，这是一种新颖而有效的注意力机制，可用于各种视觉任务。外部注意力中采用的两个外部记忆单元可以视为整个数据集的字典，能够在降低计算成本的同时学习到更具代表性的输入特征。我们希望外部注意力能够启发更多实际应用，并推动其在 NLP 等其他领域中的研究。
external-attention 结构图
The computational complexity of external attention is $O(dSN)$, where $d$ is the feature dimension, $S$ the number of memory slots, and $N$ the number of pixels. As $d$ and $S$ are hyper-parameters, the proposed algorithm is linear in the number of pixels. In fact, we find that a small $S$, e.g. 64, works well in experiments. Thus, external attention is much more efficient than self-attention, allowing its direct application to large-scale inputs.
使用方式:大概在最后一层
代码:
官方代码:
# from: https://github.com/MenghaoGuo/EANet/blob/main/model_torch.py
class External_attention(nn.Module):
    '''
    External attention block operating on 2-D feature maps (b, c, h, w).

    Arguments:
        c (int): The input and output channel number.
        k (int): Number of external memory slots (``S`` in the paper).
            Defaults to 64, the value the original code hard-coded.

    NOTE(review): ``norm_layer``, ``_BatchNorm``, ``math``, ``nn`` and ``F``
    are not defined in this snippet; they must be provided at module level
    (``norm_layer`` is presumably a BatchNorm-style class — confirm upstream).
    '''

    def __init__(self, c, k=64):
        super(External_attention, self).__init__()
        self.conv1 = nn.Conv2d(c, c, 1)
        self.k = k
        # First external memory: maps c channels to k slot logits per pixel.
        self.linear_0 = nn.Conv1d(c, self.k, 1, bias=False)
        # Second external memory: maps the k-slot attention map back to c channels.
        self.linear_1 = nn.Conv1d(self.k, c, 1, bias=False)
        # permute() returns a view, so linear_1's weight shares storage with
        # linear_0's (transposed). Any in-place change to one also changes the
        # other -- including the in-place init below, where linear_1 is
        # initialized after linear_0 and so determines both.
        self.linear_1.weight.data = self.linear_0.weight.data.permute(1, 0, 2)
        self.conv2 = nn.Sequential(
            nn.Conv2d(c, c, 1, bias=False),
            norm_layer(c))
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He-style normal init using fan-out (kernel area * out_channels).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.Conv1d):
                n = m.kernel_size[0] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, _BatchNorm):
                m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x):
        """Apply external attention; returns a tensor with the same shape as ``x``."""
        idn = x
        x = self.conv1(x)
        b, c, h, w = x.size()
        x = x.view(b, c, h * w)  # b, c, n  (n = h * w pixels)
        attn = self.linear_0(x)  # b, k, n
        # Double normalization: softmax over the n pixels, then L1-normalize
        # over the k memory slots (epsilon guards against division by zero).
        attn = F.softmax(attn, dim=-1)  # b, k, n
        attn = attn / (1e-9 + attn.sum(dim=1, keepdim=True))  # b, k, n
        x = self.linear_1(attn)  # b, c, n
        x = x.view(b, c, h, w)
        x = self.conv2(x)
        x = x + idn  # residual connection
        x = F.relu(x)
        return x
实现多头外部注意力:
官方代码:
# from: https://github.com/MenghaoGuo/EANet/blob/main/multi_head_attention_torch.py
class Attention(nn.Module):
    """Multi-head external attention over token sequences.

    ``forward`` takes ``x`` of shape (B, N, dim) and returns (B, N, dim).

    Args:
        dim: Token feature dimension; must be divisible by ``num_heads``.
        num_heads: Base number of heads. Internally it is multiplied by
            ``coef`` (4), after the features are expanded by the same factor,
            so the per-head dimension stays ``dim // num_heads``.
        qkv_bias, qk_scale: Unused. Presumably kept so this module can replace
            a ViT-style self-attention without changing call sites.
        attn_drop: Dropout probability applied to the attention map.
        proj_drop: Dropout probability applied after the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.coef = 4
        # Expand features by `coef` before splitting into heads.
        self.trans_dims = nn.Linear(dim, dim * self.coef)
        self.num_heads = self.num_heads * self.coef
        # Memory slots per head; 256 total slots shared across the coef expansion.
        self.k = 256 // self.coef
        # Shared external memories applied per head (head_dim -> k -> head_dim).
        self.linear_0 = nn.Linear(dim * self.coef // self.num_heads, self.k)
        self.linear_1 = nn.Linear(self.k, dim * self.coef // self.num_heads)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim * self.coef, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        x = self.trans_dims(x)  # B, N, C * coef
        # Split into heads: B, heads, N, head_dim  (heads = num_heads * coef)
        x = x.view(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
        attn = self.linear_0(x)  # B, heads, N, k
        # Double normalization: softmax over the N tokens, then L1 over the k slots.
        attn = attn.softmax(dim=-2)
        attn = attn / (1e-9 + attn.sum(dim=-1, keepdim=True))
        attn = self.attn_drop(attn)
        # Back to tokens: B, N, C * coef
        x = self.linear_1(attn).permute(0, 2, 1, 3).reshape(B, N, -1)
        x = self.proj(x)  # B, N, C
        x = self.proj_drop(x)
        return x
一个集成了多种注意力模块的代码仓库中的代码(非官方):
# from: https://github.com/xmu-xiaoma666/External-Attention-pytorch/blob/master/model/attention/ExternalAttention.py
import numpy as np
import torch
from torch import nn
from torch.nn import init
class ExternalAttention(nn.Module):
    """External attention over a token sequence.

    Args:
        d_model: Feature dimension of each input token.
        S: Number of external memory slots (default 64).

    ``forward`` takes ``queries`` of shape (bs, n, d_model) and returns a
    tensor of the same shape.
    """

    def __init__(self, d_model, S=64):
        super().__init__()
        self.mk = nn.Linear(d_model, S, bias=False)  # memory M_k: d_model -> S
        self.mv = nn.Linear(S, d_model, bias=False)  # memory M_v: S -> d_model
        self.softmax = nn.Softmax(dim=1)  # normalizes over the n tokens
        self.init_weights()

    def init_weights(self):
        """Initialize submodules; only the nn.Linear branch matches layers defined here."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, queries):
        attn = self.mk(queries)  # bs, n, S
        attn = self.softmax(attn)  # bs, n, S
        # L1-normalize over the S slots. The epsilon (as in the official EANet
        # code) avoids 0/0 = NaN when the token-wise softmax underflows to
        # exactly 0 in every slot for some token.
        attn = attn / (1e-9 + torch.sum(attn, dim=2, keepdim=True))  # bs, n, S
        out = self.mv(attn)  # bs, n, d_model
        return out
if __name__ == '__main__':
    # Smoke test: a batch of 50 sequences, 49 tokens each, 512 features per token.
    # Named `x` rather than `input` to avoid shadowing the builtin.
    x = torch.randn(50, 49, 512)
    ea = ExternalAttention(d_model=512, S=8)
    output = ea(x)
    print(output.shape)