【文章链接】[2111.15193] Shunted Self-Attention via Multi-Scale Token Aggregation (arxiv.org)
现有的Transformer模型采用各种降采样策略来减少特征尺寸和内存消耗。例如,VIT方法在第一层进行16 × 16的下采样投影,并在得到的粗粒度和单尺度特征图上计算自注意力;因此,所产生的特征信息损失不可避免地降低了模型的性能。PVT则致力于在细粒度特征上计算自注意力,并通过合并token的空间缩减来降低成本;然而合并过多的token会导致来自小目标和背景噪声的token混合。这样的行为反过来使得模型在捕获小物体时效果较差。
SSA的多尺度注意力机制是通过将多个注意力头拆分成若干组来实现的。每个组占一个专用的注意力粒度。对于细粒度的组,SSA学习聚合很少的token,并保留更多的局部细节。对于剩下的粗粒度头组,SSA学习聚合大量的token,从而降低计算成本,同时保留捕获大型对象的能力。多粒度组共同学习多粒度信息,使得模型能够对多尺度对象进行有效建模。
1.回忆PVT中的spatialreduction attention (SRA)
SRA是通过KV中的降低token长度来降低计算成本。
(1)token的维度是(HW,C),表示的是HW个token,特征维度是C。在计算Q,K,V的时候,需要token与三个不同的(C,C)大小的可训练参数做矩阵乘法。现在的计算量是。得到Q,K,V三者的大小(HW,C)。
Q:(HW,C)*(C,C)----->(HW,C)
K:(HW,C)*(C,C)----->(HW,C)
V:(HW,C)*(C,C)----->(HW,C)
(2)Q与K的转置相乘,可以得到(HW,HW)大小的关系矩阵,计算量是。
(HW,C)*(C,HW)----->(HW,HW)
(3)关系矩阵与V相乘,计算量是。
(HW,HW)*(HW,C)----->(HW,C)
以上,如果K和V的大小不是(HW,C),而是(hw,C),那么以上(2)(3)的矩阵乘法就变成
(HW,C)*(C,hw)----->(HW,hw) 计算量HWhwC
(HW,hw)*(hw,C)----->(HW,C) 计算量HWhwC
由此可见,只要hw小于HW,计算量就会减少,以上就是SRA的概念。
PVT中是如何实现的呢?
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
.....
#计算QKV时,可训练矩阵大小与第一步相同,没有进行空间减少
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
.....
self.sr_ratio = sr_ratio
if sr_ratio > 1:
#使用卷积操作,设置步长,来降低特征图的大小
self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
self.norm = nn.LayerNorm(dim)
2.Shunted Self-Attention(SSA)
如图所示,SSA与PVT中的SRA不同。对于SSA,在同一个自注意力层中,不同的头对应的K,V长度不同,这样可以捕获不同粒度的信息。在实践中,采用具有卷积核大小和步长均为的卷积层来实现下采样。公式为:
具体来说,对于由 i 索引的不同头,键 K 和值 V 被下采样到不同的大小。LE(·) 是通过深度卷积对 V 值进行局部增强的部分(这个操作的作用是什么哇)。
当r变大时,K、V中的token被合并,K、V的长度更短,因此计算成本较低,但仍然保留了捕获大对象的能力。相反,当 r 变小时,保留了更多细节,但带来了更多计算成本。将各种 r 集成到一个自注意力层中使其能够捕获多粒度特征。
代码中,只是被分为两种尺度,与文章中table1一致。
if sr_ratio==8:
self.sr1 = nn.Conv2d(dim, dim, kernel_size=8, stride=8)
self.norm1 = nn.LayerNorm(dim)
self.sr2 = nn.Conv2d(dim, dim, kernel_size=4, stride=4)
self.norm2 = nn.LayerNorm(dim)
if sr_ratio==4:
self.sr1 = nn.Conv2d(dim, dim, kernel_size=4, stride=4)
self.norm1 = nn.LayerNorm(dim)
self.sr2 = nn.Conv2d(dim, dim, kernel_size=2, stride=2)
self.norm2 = nn.LayerNorm(dim)
if sr_ratio==2:
self.sr1 = nn.Conv2d(dim, dim, kernel_size=2, stride=2)
self.norm1 = nn.LayerNorm(dim)
self.sr2 = nn.Conv2d(dim, dim, kernel_size=1, stride=1)
self.norm2 = nn.LayerNorm(dim)
3.Detail-specific Feedforward Layers
在传统的前馈层中,全连接层是逐点的,无法学习交叉令牌信息。在这里,我们的目标是通过指定前馈层中的细节来补充局部信息。如图 6 所示,我们通过在前馈层的两个全连接层之间添加数据特定层来补充前馈层中的局部细节:
(注意:1、图片与公式相比,激活层的位置对不上。2、PVT指的是PVT_v2)
前馈层公式:(图片与公式相比,激活层的位置对不上)
代码与公式一致:
def forward(self, x, H, W):
x = self.fc1(x)
x = self.act(x + self.dwconv(x, H, W))
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x