An Image is worth 16x16 words:transformers for image recognition at scale

摘要

虽然Transformer架构已经是NLP领域的一个标准,但是应用transformer到CV领域效果还是很有限的。在视觉领域,自注意力要么和卷积神经网络一起使用,要么是将卷积神经网络里的卷积替换成自注意力。但是仍然保持整体结构不变。本文证明了这种对于卷积神经网络的依赖是完全不必要的,一个纯的transformer直接作用于一系列图像块的时候也可以在图像分类任务上表现很好。尤其是当我们在大规模的数据集上进行预训练,然后迁移到中小型数据集上使用时,Vision transformer能获得和最好的卷积神经网络相媲美的结果。这里我们将CIFAR-100,Imagenet、VATB当作小数据集。

1 引言

自注意力机制网络,尤其是Transformer已经是自然语言处理领域的必选模型了。现在比较主流的方式是先去一个大规模的数据集上去做预训练,然后再在一些特定领域的小数据集上进行微调(fine-tune)。多亏了Transformer的计算高效性和可扩展性,现在已经可以训练超过1000亿参数的模型了。随着模型和数据集的增长,我们还没有看到任何性能饱和的现象。

这里介绍下将Transformer应用到视觉问题上的一些难处。先回顾下transformer,假设我们有如下的Transformer的Encoder以及一些输入元素(在自然语言处理中,这就是一个句子里面的一个一个的单词):
在这里插入图片描述
Transformer里面最主要的就是自注意力操作,自注意力操作就是每个元素都要和每个元素去做互动,然后计算得到一个Attention图,接下来用这个Attention图去进行加权平均,最后得到输出:
在这里插入图片描述
因为在做自注意力时我们是两两相互的,这个计算复杂度是和序列的长度成乘方关系 O ( n 2 ) O(n^2) O(n2),目前一般在自然语言处理领域,硬件能够支持的序列长度也就是几百或者上千。在BERT里,序列长度是512。

在视觉领域,首先要解决的任务是将2d图片转化为1d的序列。最直观的方式是把像素图拉直,但是复杂度太高 224 × 224 = 50176 224\times224=50176 224×224=50176

回到引言第二段,在视觉领域,卷积神经网络仍然占主导地位。但是由于到Transformer在NLP领域的成功,现在许多工作尝试把CNN结构和自注意力结构结合在一起,还有一些工作将整个卷积神经都替换掉了,全部使用自注意力。这些方法其实都是在降低序列长度。完全用自注意力替代卷积的这一类工作虽然理论上非常高效,但事实上由于他们的自注意力操作都是一些比较特殊的自注意力操作,没有应用到现有的硬件结构进行加速,所以很难训练一个较大的模型。因此,在大规模的图像识别上,传统的ResNet结构网络还是效果最好的。

作者被Transformer在NLP领域的可扩展性启发,他们希望直接将一个标准的Transformer应用于图片,尽量少做修改(不针对视觉任务进行特定的改变)。Vision Transformer将一个图片打成了很多 16 × 16 16\times 16 16×16的patch,此时宽度和高都是 224 / 16 = 14 224/16=14 224/16=14,最后序列长度就变成了196。Vision Transformer将每个patch作为一个元素,通过fc layer可以得到linear embedding,这些会被当作输入传递给transformer。我们可以把这些patch作为NLP里面的单词,训练Vision Transformer使用的是有监督训练。

当在中型大小的数据集上(比如ImageNet)进行训练时,如果不加比较强的约束(strong regularization),ViT和同等大小的残差网络相比,是要弱几个点的。这个看起来不太好的结果其实是可以预期的,因为Transformer和卷积神经网络相比,缺少一些卷积神经网络有的归纳偏置(inductive biases,这指一些先验知识,即我们提前做好的假设,对于卷积神经网络来说,我们常说的有两个inductive bias,一个叫做locality,由于卷积神经网络以滑动窗口的形式一点点在图片上进行卷积,所以假设图片上相邻的区域会有相邻的特征,另外一个inductive bias叫做平移等变性,translation equivariance, f ( g ( x ) ) = g ( f ( x ) ) f(g(x))=g(f(x)) f(g(x))=g(f(x)),可以把 g g g理解为平移, f f f理解为卷积),由于拥有这些inductive bias,卷积神经网络具有许多先验信息,所以可以应用相对少的数据来训练一个较好的模型,但是对于Transformer来说,没有这些先验信息,对于vision领域的知识全部需要自己学习。

为了验证这个假设,作者在更大的数据集上进行了预训练(14M指代ImageNet 22K数据集,300M指代Google的JFT 300M数据集。大规模的预训练表明优于归纳偏置。Vision Transformer只要在有足够数据进行预训练的情况下就能在下游任务上获得较好的迁移学习效果。在ImageNet 21k或者JFT-300M上进行训练时,ViT就能获得和现在最好的残差网络相近,或者说更好的结果。具体而言,在ImageNet上实现了88.55%,在ImageNet-ReaL上实现了90.72%,在CIFAR-100上实现了94.55%,在VTAB上实现了77.63%(这个数据集融合了19个数据集,主要用于测试鲁棒性)。

2 相关工作

24.39

3 方法

模型设计是尽可能贴近原始的transformer,这样做的好处是可以直接把NLP那边已经成功的Transformer架构直接拿过来用,不用自己再去魔改模型。而且因为Transformer已经在NLP领域火了这么久,现在有一些写得非常高效的实现,同样Vision Transformer可以直接拿过来用。

3.1 Vision transformer

模型总览图如图一所示:

在这里插入图片描述
标准的Transformer需要一系列1D序列作为输入,所以为了符合Transformer的结构,我们将图片 x ∈ R H × W × C x\in\mathbb{R}^{H\times W\times C} xRH×W×C变成一系列展平的2D patches x p ∈ R N × ( P 2 × C ) x_p\in\mathbb{R}^{N\times (P^2\times C)} xpRN×(P2×C),这里 ( H , W ) (H,W) (H,W)是原始图片的分辨率, C C C是通道数, ( P , P ) (P,P) (P,P)是每个图片patch的分辨率, N = H W / P 2 N=HW/P^2 N=HW/P2是patch的数量,这就是最终传入Transformer的有效序列长度。Transformer从头到尾都是使用 D D D作为向量长度(768),为了和Transformer的维度相匹配,所以我们的图像patch维度也设定为768(具体做法是使用了一个可以训练的linear projection,即全连接层)。从这个全连接层出来的向量我们称之为patch embedding。

为了进行最后的分类,作者借鉴了BERT里面的 [ c l a s s ] [class] [class] token,这个token是一个可以学习的特征,且和图像的特征具有相同的维度,token初始表示为 z 0 0 = x c l a s s z_0^0=x_{class} z00=xclass,经过多层Transformer处理后,我们将token表示为 z L 0 z_L^0 zL0,此时我们将这个token当成整个Transformer的输出,也就是当作整个图片的特征。在pre-training以及fine-tuning阶段,一个分类头都连接到了 z L 0 z_L^0 zL0。这个分类头是一个MLP,这个MLP在pre-training阶段有一个hidden layer,在fine-tuning阶段有一个linear layer。

位置编码信息被添加到patch embedding中来保留位置信息。本文使用标准的可以学习的1D position embedding,也就是BERT里面使用的位置编码。作者也尝试了其他编码形式(因为我们是针对图像任务),例如一个2D-aware的位置编码。实验显示最后的结果相差不大。

作者用公式描述了整体的过程:

z 0 = [ x c l a s s ; x p 1 E ; x p 2 E ; …   ; x p N E ] + E p o s , E ∈ R ( N + 1 ) × D z_0=[x_{class};x_p^1E;x_p^2E;\dots;x_p^NE]+E_{pos},\quad\quad E\in\mathbb{R}^{(N+1)\times D} z0=[xclass;xp1E;xp2E;;xpNE]+Epos,ER(N+1)×D

这里 x p 1 , x p 2 x_p^1,x_p^2 xp1,xp2等其实就是这些图像块中的patch,一共有 N N N个patch,每个patch先和linear projection(这里表示为 E E E)进行转换,从而得到patch embedding,在得到这些linear embedding后,我们在前面拼接一个class embedding,利用它得到最后的输出。在得到了所有的tokens后,我们需要对这些token进行位置编码,我们将位置编码信息 E p o s E_{pos} Epos直接加入矩阵中,此时 z 0 z_0 z0就是transformer的输入了。接下来是一个循环:

z ℓ ′ = MSA ( LN ( z ℓ − 1 ) ) + z ℓ − 1 ℓ = 1... L z ℓ = MLP ( LN ( z ℓ ′ ) ) + z ℓ ′ ℓ = 1... L z'_{\ell}=\text{MSA}(\text{LN}(z_{\ell-1}))+z_{\ell-1}\quad\quad\quad \ell=1...L\\ z_{\ell}=\text{MLP}(\text{LN}(z'_{\ell}))+z'_{\ell}\quad\quad\quad\quad\quad\ell=1...L z=MSA(LN(z1))+z1=1...Lz=MLP(LN(z))+z=1...L

对于每个Transformer block来说,里面都有两个操作,一个是MLP,另外一个是MSA(多头自注意力),在进行这两个操作前,我们要先经过layer norm(LN),然后每一层出来的结果都要去再经过一次残差连接。 z ℓ ′ z_{\ell}' z就是多头自注意力的结果, z ℓ z_{\ell} z就是每个Transformer block整体出来的结果。在l层循环后,我们将 z L 0 z_{L}^0 zL0,也就是最后一层的第一个token拿出来,当作整体图像的一个特征,从而去做最后的这个分类任务。

归纳偏置:Vision Transformer相比CNN而言少了许多图像特有的归纳偏置,比如CNN里面存在的locality和translation equivariance。在ViT中,只有MLP层是局部且平移等变性的。但是自注意力层是全局的。

混合结构:我们原来有一张图片,然后我们用Res50等结构去处理得到特征图( 14 × 14 14\times14 14×14),这时这个特征图也是196个元素,然后我们用新得到的这196个元素去和全连接层进行操作。得到新的patch embedding,其实这就是两种不同的对图片预处理的方式。

3.2

52分钟左右

其主要包括以下模块:

图片预处理:

作者将 x ∈ R H × W × C x\in\mathbb{R}^{H\times W\times C} xRH×W×C的图片,变成一个 x p ∈ R N × ( P 2 ⋅ C ) x_p\in\mathbb{R}^{N\times (P^2\cdot C)} xpRN×(P2C)的sequence of flattened 2D patches。这可以看做是一个2D块序列,序列中一共有 N = H W / P 2 N=HW/P^2 N=HW/P2个展平的2D块,每个块的维度是 ( P 2 ⋅ C ) (P^2\cdot C) (P2C),其中 P P P是块大小, C C C是通道数。

作者进行这步的意图是:因为Transformer希望输入是一个二维的矩阵 ( N , D ) (N,D) (N,D),其中 N N N是序列长度, D D D是序列中每个向量的维度(常用256)。所以这里我们也要设法将 H × W × C H\times W\times C H×W×C的三维图片转化为 ( N , D ) (N,D) (N,D)的二维输入。

对应代码是:

x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)

具体使用了einops库,具体可以参考这篇文章

现在得到的向量维度是: x p ∈ N × ( P 2 ⋅ C ) x_p\in N\times (P^2\cdot C) xpN×(P2C),要转化成 ( N , D ) (N,D) (N,D)的二维输入,我们还需要进一步叫做Patch Embedding的步骤。

Patch Embedding

这步要做的是对每个向量都做一个线性变换(即全连接层),压缩后的维度为 D D D,我们称其为Patch Embedding。

z 0 = [ x c l a s s ;   x p 1 E ; x p 2 E ; …   ;   x p N E ] + E p o s , E ∈ R ( P 2 ⋅ C ) × D ,   E p o s ∈ R ( N + 1 ) × D z_0=[x_{class};\ x_p^1E; x_p^2E;\dots;\ x_p^NE]+E_{pos},\quad\quad E\in\mathbb{R}^{(P^2\cdot C)\times D},\ E_{pos}\in\mathbb{R}^{(N+1)\times D} z0=[xclass; xp1E;xp2E;; xpNE]+Epos,ER(P2C)×D, EposR(N+1)×D

全连接层就是上式中的 E E E,它的输入维度大小是 ( P 2 ⋅ C ) (P^2\cdot C) (P2C),输出维度大小是 D D D

注意上面式子中存在一个 x c l a s s x_{class} xclass,假设将整个图片切成9个块,但是最终输入到transformer中的是10个向量,这是人为增加的一个向量。

为什么要追加这个向量?

如果没有这个向量,假设 N = 9 N=9 N=9个向量输入transformer encoder,输出9个编码向量,然后呢?对于分类任务而言,我们应该用哪个输出向量进行后续分类呢?

所以我们干脆使用一个向量 x c l a s s ( v e c t o r , d i m = D ) x_{class}(vector,dim=D) xclass(vector,dim=D),这个向量是可学习的嵌入向量,它和那9个向量一起输入transformer encoder,输出1+9个编码向量,然后使用第0个编码向量,即 x c l a s s x_{class} xclass的输出进行分类预测即可。

这么做的原因可以理解为:ViT只用到了transformer的encoder,而并没有用到decoder,而 x c l a s s x_{class} xclass的作用有点类似于解码器中的Query的作用,相对应的Key,Value就是其他9个编码向量的输出。

x c l a s s x_{class} xclass是一个可学习的嵌入向量,它的意义说通俗一点为:寻找其他9个输入向量对应的image的类别。

代码为:

 dim=1024
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

# forward前向代码
# 变成(b,64,1024)
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
# 跟前面的分块进行concat
# 额外追加token,变成b,65,1024
x = torch.cat((cls_tokens, x), dim=1)
Positional Encoding

按照transformer的位置编码方式,在本文中同样使用了位置编码。引入了一个positional encoder E p o s E_{pos} Epos 来加入位置信息,同样在这里引入了pos embedding,这是一个可训练的变量。

在ViT中,我们没有使用原版Transformer的sincos编码,而是直接设置为可学习的Positional Encoding,这两个的效果差不多。我们对训练好的pos embedding进行可视化,如下图所示。

在这里插入图片描述
可以发现,位置越接近,往往具有更相似的位置编码。此外,还出现了行列结构;同一行/列中的patch具有相似的位置编码。

代码表示如下:

# num_patches=64,dim=1024,+1是因为多了一个cls开启解码标志
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))

Transformer的前向过程:

z 0 = [ x c l a s s ;   x p 1 E ; x p 2 E ; …   ;   x p N E ] + E p o s , E ∈ R ( P 2 ⋅ C ) × D ,   E p o s ∈ R ( N + 1 ) × D z l ′ = MSA ( LN ( z l − 1 ) ) + z l − 1 ,   l = 1 … , L z l = MLP ( LN ( z l ′ ) ) + z l ′ , l = 1 … L y = LN ( z L 0 ) z_0=[x_{class};\ x_p^1E; x_p^2E;\dots;\ x_p^NE]+E_{pos},\quad\quad\quad\quad E\in\mathbb{R}^{(P^2\cdot C)\times D},\ E_{pos}\in\mathbb{R}^{(N+1)\times D}\\ z_l'=\text{MSA}(\text{LN}(z_{l-1}))+z_{l-1},\quad\quad\quad\quad\quad\quad\quad\ l=1\dots,L\\z_l=\text{MLP}(\text{LN}(z_l'))+z_l',\quad\quad\quad\quad\quad\quad\quad\quad\quad l=1\dots L\\y=\text{LN}(z_L^0)\quad\quad\quad\quad\quad\quad\quad\quad\quad\quad z0=[xclass; xp1E;xp2E;; xpNE]+Epos,ER(P2C)×D, EposR(N+1)×Dzl=MSA(LN(zl1))+zl1, l=1,Lzl=MLP(LN(zl))+zl,l=1Ly=LN(zL0)

其中,第一个式子为上面提及的Patch Embedding 和 Positional Encoding的过程。

第二个式子为Transformer Encoder的Multi-head Self-attention, Add and Norm的过程,重复L次。

第三个式子为Transformer Encoder的Feed Forward Network,Add and Norm的过程,重复L次。

作者采用的是没有任何改动的transformer。

最后一个是MLP的Classification Head,整体的结构只有这些,如下图所示(变量的维度变化过程标注在了图中):
在这里插入图片描述

4 实验

这个章节主要对比了ResNet,Vision Transformer以及混合模型的表征学习能力。为了了解训练好每个模型到底需要多少数据,我们在不同大小的数据集上进行预训练,然后在很多benchmark上进行测试。当考虑到预训练的代价,即预训练的时间长短时,ViT表现的非常好,能在大多数数据集上取得最好的结果。同时需要更少的时间去训练。最后作者还做了一个小小的自监督实验,结果还可以,说自监督的ViT还是比较有潜力的。

4.1 设置

数据集:作者使用了ILSVRC-2012 ImageNet数据集,同时使用了大家最普遍使用的这1000个类(称为ImageNet-1k),和更大规模的数据集(ImageNet-21k),作者还使用了JFT数据集(Google自己的数据集,包含了三亿张图片)。

5 结论

  • 6
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值