Starting from the problem of training deeper ViTs, the authors run extensive experiments and find that the attention maps of deep blocks are highly similar to one another — in some layers nearly identical — which is why naively deepened ViTs stop improving. They also observe that the similarity between the different heads of multi-head attention is low. This motivates the paper's idea: regenerate the attention maps by exchanging information across heads, which reduces the similarity between deep blocks' attention maps and makes deeper ViTs trainable.
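The cross-layer similarity in question can be quantified, for instance, as the cosine similarity between two layers' attention maps, averaged over heads and query tokens. A minimal numpy sketch (the function name and exact averaging are my own; the paper's metric may differ in detail):

```python
import numpy as np

def attention_similarity(attn_a, attn_b, eps=1e-8):
    """Mean cosine similarity between two layers' attention maps.

    attn_a, attn_b: arrays of shape (heads, tokens, tokens); each query
    row is compared as a vector, then similarities are averaged.
    """
    num = (attn_a * attn_b).sum(-1)
    den = np.linalg.norm(attn_a, axis=-1) * np.linalg.norm(attn_b, axis=-1) + eps
    return (num / den).mean()

rng = np.random.default_rng(0)
attn = rng.random((8, 16, 16))  # 8 heads, 16 tokens
sim_same = attention_similarity(attn, attn)                     # ≈ 1.0 for identical maps
sim_rand = attention_similarity(attn, rng.random((8, 16, 16)))  # lower for unrelated maps
print(sim_same, sim_rand)
```

Under this metric, the paper's observation is that `sim_same`-like values (close to 1) show up between *different* deep layers of a vanilla ViT.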
Compared with the original ViT, only the Self-Attention layer is replaced:
The change itself is simple: a head-mixing transformation matrix and a normalization layer are added.
From the formula perspective, Re-Attention replaces standard self-attention with

$$\mathrm{Re\text{-}Attention}(Q, K, V) = \mathrm{Norm}\!\left(\Theta^{\top}\,\mathrm{Softmax}\!\left(\frac{QK^{\top}}{\sqrt{d}}\right)\right)V$$

i.e., it adds a learnable head-mixing matrix $\Theta \in \mathbb{R}^{H \times H}$ (where H is the number of heads) and a Norm layer.
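The formula can be sketched in a few lines of numpy. Here the per-head standardization stands in for the BatchNorm used in the actual implementation (an assumption for illustration; the learned affine parameters and running statistics are omitted):

```python
import numpy as np

def re_attention(attn, theta, eps=1e-5):
    """Re-Attention on a single sample: Norm(Θ^T · A), per the formula above.

    attn:  (H, N, N) softmaxed attention maps.
    theta: (H, H) learnable head-mixing matrix Θ.
    """
    # Apply Θ^T across the head dimension: new head g mixes all old heads h.
    mixed = np.einsum('gh,hnm->gnm', theta.T, attn)
    # Standardize each head's map (stand-in for BatchNorm2d over heads).
    mean = mixed.mean(axis=(1, 2), keepdims=True)
    var = mixed.var(axis=(1, 2), keepdims=True)
    return (mixed - mean) / np.sqrt(var + eps)

H, N = 4, 8
rng = np.random.default_rng(0)
attn = rng.random((H, N, N))
theta = np.eye(H) + 0.1 * rng.standard_normal((H, H))  # near-identity mixing
out = re_attention(attn, theta)
print(out.shape)  # (4, 8, 8)
```

The key point is that each regenerated head is a learned combination of *all* heads, so deep layers stop collapsing onto one shared attention pattern.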
From the code perspective (PyTorch):
```python
import torch
import torch.nn as nn


class ReAttention(nn.Module):
    """
    It is observed that similarity along the same batch of data is extremely large.
    Thus we can reduce the batch dimension when calculating the attention map.
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None,
                 attn_drop=0., proj_drop=0., expansion_ratio=3,
                 apply_transform=True, transform_scale=False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.apply_transform = apply_transform
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5
        if apply_transform:
            # 1x1 convolution over the head dimension: the learnable H x H matrix Θ.
            self.reatten_matrix = nn.Conv2d(self.num_heads, self.num_heads, 1, 1)
            self.var_norm = nn.BatchNorm2d(self.num_heads)
            self.reatten_scale = self.scale if transform_scale else 1.0
        self.qkv = nn.Linear(dim, dim * expansion_ratio, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, atten=None):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        if self.apply_transform:
            # Re-Attention: mix heads with Θ (the 1x1 conv), then normalize.
            attn = self.var_norm(self.reatten_matrix(attn)) * self.reatten_scale
        attn_next = attn
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn_next
```
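One detail worth noting: Θ is implemented above as a 1×1 convolution over the head dimension (`reatten_matrix`). The following self-contained check (with bias disabled for a clean comparison; the layer above keeps the default bias) confirms that such a convolution is exactly a head-mixing matrix multiply:

```python
import torch
import torch.nn as nn

H, N = 8, 16
conv = nn.Conv2d(H, H, 1, 1, bias=False)        # stands in for reatten_matrix
attn = torch.rand(2, H, N, N)                    # (batch, heads, N, N) attention maps
via_conv = conv(attn)
theta = conv.weight.view(H, H)                   # (out_head, in_head) mixing matrix
via_matmul = torch.einsum('gh,bhnm->bgnm', theta, attn)
print(torch.allclose(via_conv, via_matmul, atol=1e-5))  # the two paths agree
```

Using a convolution rather than an explicit matrix multiply keeps the operation in standard `nn` modules and broadcasts naturally over the batch and token dimensions.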
The authors then run a series of experiments showing that this method works better than the alternatives they tried, including increasing the embedding dimension, raising the self-attention temperature, and dropping attention layers. Re-attention adds almost no computation yet effectively enables deep ViTs by reducing the similarity between deep blocks' attention maps. Compared against other models it achieves state-of-the-art results; the experimental results are attached:
That's all.