DeepViT: Towards Deeper Vision Transformer
code
Paper Contributions
•We take a closer look at the behavior of vision transformers and observe that, unlike CNNs, they do not consistently benefit from stacking more layers. We further identify the underlying cause of this counter-intuitive phenomenon and are the first to characterize it as attention collapse.
•We propose Re-Attention, a simple yet effective attention mechanism that exchanges information across different attention heads.
$$\text{Re-Attention}(Q,K,V)=\mathrm{Norm}\!\left(\Theta^{\top}\,\mathrm{Softmax}\!\left(\frac{QK^{\top}}{\sqrt{d}}\right)\right)V$$
•To the best of our knowledge, we are the first to successfully train a ViT with 32 transformer blocks on ImageNet-1k.
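The Re-Attention equation above can be sketched in plain NumPy. This is a minimal illustration, not the paper's implementation: the function name `re_attention`, the per-head tensor layout `(H, N, d)`, and the simple standardization used as a stand-in for Norm (the paper uses BatchNorm) are assumptions for illustration.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def re_attention(Q, K, V, theta):
    """Re-Attention sketch for a single example.
    Q, K, V: (H, N, d) per-head queries/keys/values; theta: (H, H) head-mixing matrix."""
    H, N, d = Q.shape
    A = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(d))   # (H, N, N) per-head attention maps
    A = np.einsum('gh,hnm->gnm', theta.T, A)             # mix attention maps across heads: Theta^T A
    A = (A - A.mean()) / (A.std() + 1e-6)                # stand-in for Norm (paper uses BatchNorm)
    return A @ V                                         # (H, N, d)

rng = np.random.default_rng(0)
H, N, d = 4, 16, 8
out = re_attention(rng.normal(size=(H, N, d)),
                   rng.normal(size=(H, N, d)),
                   rng.normal(size=(H, N, d)),
                   rng.normal(size=(H, H)))
print(out.shape)  # (4, 16, 8)
```

The key point is the middle line: each new attention map is a learnable linear combination of all H original head maps, which is what regenerates diversity across heads.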
Attention Collapse
Motivated by the success of deep CNNs, we conduct a systematic study of how ViT performance changes with increasing depth. Without loss of generality, we first fix the hidden dimension and the number of heads to 384 and 12 respectively, following common practice [37]. We then stack different numbers of transformer blocks (ranging from 12 to 32) to build several ViT models of different depths. Overall image classification performance is evaluated on ImageNet [18] and summarized in Figure 1. As the performance curves show, we find, surprisingly, that classification accuracy improves slowly and saturates quickly as the model goes deeper. More specifically, we observe that the improvement stops once 24 transformer blocks are used. This phenomenon indicates that existing ViTs have difficulty benefiting from deeper architectures.
This counter-intuitive problem is worth exploring, since a similar issue (namely, how to effectively train deeper models) was also observed for CNNs in their early development but was later properly solved. Digging into the transformer architecture, we emphasize that the self-attention mechanism plays a key role in ViTs and distinguishes them significantly from CNNs. We therefore start by studying self-attention, or more specifically, how the generated attention maps A change as the model goes deeper.
To measure the evolution of attention maps across layers, we compute the following cross-layer similarity between the attention maps of different layers:
$$M_{h,t}^{p,q}=\frac{ \left(A_{h,:,t}^{p}\right)^{\top} A_{h,:,t}^{q} }{ \lVert A_{h,:,t}^{p}\rVert \,\lVert A_{h,:,t}^{q}\rVert }$$
where $M^{p,q}$ is the cosine similarity matrix between the attention maps of layers p and q. Each element $M_{h,t}^{p,q}$ measures the attention similarity for head h and token t. For a particular self-attention layer and its h-th head, $A_{h,:,t}^{*}$ is a T-dimensional vector that represents the contribution of input token t to each of the T output tokens. $M_{h,t}^{p,q}$ therefore provides a suitable measure of how a token's contribution changes from layer p to layer q. When $M_{h,t}^{p,q}$ equals 1, token t plays exactly the same role in the self-attention of layers p and q.
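This similarity can be computed in a few lines of NumPy. A minimal sketch: the function name `cross_layer_similarity` and the attention-map shape `(H, T, T)` per layer are assumptions for illustration.

```python
import numpy as np

def cross_layer_similarity(A_p, A_q):
    """Cosine similarity M[h, t] between the attention maps of two layers.
    A_p, A_q: (H, T, T) attention maps; the column A[h, :, t] is token t's
    contribution to every output token in head h."""
    num = np.einsum('hst,hst->ht', A_p, A_q)                         # dot products over the output-token axis
    den = np.linalg.norm(A_p, axis=1) * np.linalg.norm(A_q, axis=1)  # product of column norms
    return num / (den + 1e-12)                                       # (H, T)

rng = np.random.default_rng(0)
A = rng.random((2, 12, 16, 16))   # 2 layers, 12 heads, 16 tokens
M = cross_layer_similarity(A[0], A[1])
same = cross_layer_similarity(A[0], A[0])
print(M.shape, np.allclose(same, 1.0))  # (12, 16) True
```

As the self-similarity case shows, $M_{h,t}^{p,q}=1$ exactly when the two layers treat token t identically, which is the degenerate regime the paper calls attention collapse.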
Given Eq. (2), we then trained a ViT model with 32 transformer blocks on ImageNet-1k and examined the above similarity among all of its attention maps.
As shown in Figure 3(a), after the 17th block the fraction of similar attention maps measured by M exceeds 90%. This indicates that the attention maps learned afterwards are alike, and the transformer blocks may degenerate into MLPs.
Consequently, stacking more such degenerated MHSA layers may introduce a model rank-degeneration problem (i.e., the rank of the model parameter tensor produced by multiplying the layer-wise parameters decreases) and limit the model's learning capacity. Our analysis of learned-feature degeneration, presented below, also verifies this. The observed attention collapse may be one of the reasons for the performance saturation observed in ViTs. To further verify whether this phenomenon exists in ViTs of different depths, we ran the same experiment on ViTs with 12, 16, 24, and 32 transformer blocks and counted the number of blocks with similar attention maps. The results in Figure 3(b) clearly show that the ratio of blocks with similar attention maps to the total number of blocks grows as more transformer blocks are added.
To understand how attention collapse affects the performance of ViT models, we further investigate how it influences feature learning in the deeper layers. For the specific 32-block ViT model, we compare the final output feature with the output of each intermediate transformer block by examining their cosine similarity.
The results in Figure 4 show that the similarity is very high and the features stop evolving after the 20th block. The increase in attention similarity is closely correlated with the increase in feature similarity. This observation suggests that attention collapse is responsible for the non-scalability of ViTs.
Model comparison (with the original self-attention replaced by Re-Attention)
Re-Attention Module
Figure 7: (Left) the original self-attention mechanism; (Right) our proposed Re-Attention mechanism.
As shown in the figure, the original attention maps are mixed through a learnable matrix Θ before being multiplied with the values (`self.reatten_matrix = nn.Conv2d(self.num_heads, self.num_heads, 1, 1)`).
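A 1×1 convolution over the head dimension is exactly a learnable linear mix of the H per-head attention maps. The sketch below (an illustration, not the repository's code) checks this equivalence against an explicit einsum using the conv's own weight:

```python
import torch
import torch.nn as nn

H, N = 12, 16                       # heads, tokens
attn = torch.rand(1, H, N, N)       # attention maps treated as an H-channel "image"

mix = nn.Conv2d(H, H, kernel_size=1, bias=True)
out_conv = mix(attn)

# The same operation written explicitly: each output head is a weighted
# sum of all input head maps (the learnable Theta), plus a per-head bias.
theta = mix.weight.squeeze(-1).squeeze(-1)            # (H, H)
out_ein = torch.einsum('gh,bhnm->bgnm', theta, attn) + mix.bias.view(1, H, 1, 1)

print(torch.allclose(out_conv, out_ein, atol=1e-6))   # True
```

Using a 1×1 conv rather than an explicit matrix keeps the operation batched and lets the subsequent `BatchNorm2d` normalize per head-channel.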
Effect Visualization
Figure 6: Visualization of the attention maps of selected blocks in the baseline ViT model with 32 transformer blocks.
The first row is based on the original self-attention module and the second row on Re-Attention. It can be seen that the model only learns local patch relations in its shallow blocks, with the remaining attention values close to zero. Although the attention span gradually grows as the blocks go deeper, the attention maps tend to become uniform and thus lose diversity. After adding Re-Attention, the previously similar attention maps become distinct, as shown in the second row. Only in the last block is a nearly uniform attention map learned.
ReAttention Code
class ReAttention(nn.Module):
    """
    It is observed that similarity along the same batch of data is extremely large.
    Thus we can reduce the batch dimension when calculating the attention map.
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,
                 expansion_ratio=3, apply_transform=True, transform_scale=False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.apply_transform = apply_transform
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5
        if apply_transform:
            # Learnable head-mixing matrix Theta, implemented as a 1x1 conv over the head dimension
            self.reatten_matrix = nn.Conv2d(self.num_heads, self.num_heads, 1, 1)
            self.var_norm = nn.BatchNorm2d(self.num_heads)
            self.reatten_scale = self.scale if transform_scale else 1.0
        self.qkv = nn.Linear(dim, dim * expansion_ratio, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, atten=None):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        if self.apply_transform:
            # Re-Attention: mix the per-head attention maps with Theta, then normalize
            attn = self.var_norm(self.reatten_matrix(attn)) * self.reatten_scale
        attn_next = attn  # returned so the caller can inspect or reuse the attention map
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn_next
Full code from the ReAttention page
import torch
import torch.nn as nn
import numpy as np
from functools import partial
import torch.nn.init as init
import torch.nn.functional as F
from timm.models.layers import DropPath, to_2tuple

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., expansion_ratio=3):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.act = act_layer()
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., expansion_ratio=3):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.expansion = expansion_ratio
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * self.expansion, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, atten=None):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn
class ReAttention(nn.Module):
    """
    It is observed that similarity along the same batch of data is extremely large.
    Thus we can reduce the batch dimension when calculating the attention map.
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,
                 expansion_ratio=3, apply_transform=True, transform_scale=False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.apply_transform = apply_transform
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5
        if apply_transform:
            # Learnable head-mixing matrix Theta, implemented as a 1x1 conv over the head dimension
            self.reatten_matrix = nn.Conv2d(self.num_heads, self.num_heads, 1, 1)
            self.var_norm = nn.BatchNorm2d(self.num_heads)
            self.reatten_scale = self.scale if transform_scale else 1.0
        self.qkv = nn.Linear(dim, dim * expansion_ratio, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, atten=None):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        if self.apply_transform:
            # Re-Attention: mix the per-head attention maps with Theta, then normalize
            attn = self.var_norm(self.reatten_matrix(attn)) * self.reatten_scale
        attn_next = attn  # returned so the caller can inspect or reuse the attention map
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn_next
class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, expansion=3,
                 group=False, share=False, re_atten=False, bs=False, apply_transform=False,
                 scale_adjustment=1.0, transform_scale=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.re_atten = re_atten
        self.adjust_ratio = scale_adjustment
        self.dim = dim
        if self.re_atten:
            self.attn = ReAttention(
                dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
                expansion_ratio=expansion, apply_transform=apply_transform, transform_scale=transform_scale)
        else:
            self.attn = Attention(
                dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
                expansion_ratio=expansion)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, atten=None):
        if self.re_atten:
            x_new, atten = self.attn(self.norm1(x * self.adjust_ratio), atten)
            x = x + self.drop_path(x_new / self.adjust_ratio)
            x = x + self.drop_path(self.mlp(self.norm2(x * self.adjust_ratio))) / self.adjust_ratio
            return x, atten
        else:
            x_new, atten = self.attn(self.norm1(x), atten)
            x = x + self.drop_path(x_new)
            x = x + self.drop_path(self.mlp(self.norm2(x)))
            return x, atten
class PatchEmbed_CNN(nn.Module):
    """
    Following T2T, we use 3 layers of CNN for comparison with other methods.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, spp=32):
        super().__init__()
        new_patch_size = to_2tuple(patch_size // 2)
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.conv1 = nn.Conv2d(in_chans, 64, kernel_size=7, stride=2, padding=3, bias=False)  # 112x112
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)  # 112x112
        self.bn2 = nn.BatchNorm2d(64)
        self.proj = nn.Conv2d(64, embed_dim, kernel_size=new_patch_size, stride=new_patch_size)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.proj(x).flatten(2).transpose(1, 2)  # [B, num_patches, embed_dim]
        return x
class PatchEmbed(nn.Module):
    """
    Same embedding as the timm lib.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x
class HybridEmbed(nn.Module):
    """
    Same embedding as the timm lib.
    """
    def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        img_size = to_2tuple(img_size)
        self.img_size = img_size
        self.backbone = backbone
        if feature_size is None:
            with torch.no_grad():
                training = backbone.training
                if training:
                    backbone.eval()
                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                backbone.train(training)
        else:
            feature_size = to_2tuple(feature_size)
            feature_dim = self.backbone.feature_info.channels()[-1]
        self.num_patches = feature_size[0] * feature_size[1]
        self.proj = nn.Linear(feature_dim, embed_dim)

    def forward(self, x):
        x = self.backbone(x)[-1]
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)
        return x
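To make the patch-embedding shapes concrete: a convolution with kernel size equal to its stride cuts a 224×224 image into (224/16)² = 196 non-overlapping patches, each projected to the embedding dimension. A minimal self-contained sketch (the sizes below are the common ViT defaults, used here only for illustration):

```python
import torch
import torch.nn as nn

img_size, patch_size, embed_dim = 224, 16, 768
# kernel_size == stride: each 16x16 patch is linearly projected to embed_dim
proj = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)

x = torch.randn(2, 3, img_size, img_size)        # a batch of 2 RGB images
tokens = proj(x).flatten(2).transpose(1, 2)      # (B, embed_dim, 14, 14) -> (B, 196, embed_dim)
print(tokens.shape)  # torch.Size([2, 196, 768])
```

This is exactly what `PatchEmbed.forward` above does before the token sequence is fed into the stack of transformer blocks.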