论文名称:AN IMAGE IS WORTH 16X16 WORDS:
TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE
这篇博客参考csdn上一篇博客以及自己的理解总结:Vision Transformer详解_太阳花的小绿豆的博客-CSDN博客_vision transformer
首先作者在论文中提到,在中等大小的数据集上进行训练时,他们提出的方式没有ResNets的效果好,这是因为运用transformer的思想处理图片缺少了CNN中的inductive bias,CNN具有两种归纳偏置:
- 一种是局部性(locality/two-dimensional neighborhood structure),即图片上相邻的区域具有相似的特征;
- 一种是平移不变形(translation equivariance):
其中g代表卷积操作,f代表平移操作。当CNN具有以上两种归纳偏置,就有了很多先验信息,需要相对少的数据就可以学习一个比较好的模型.
但是当训练数据足够大的时候,Vit的表现就会超过CNN,突破了归纳偏置的限制。
模型结构图:
根据图片内容可以了解,图片最重要的部分分别是:
- Embedding部分(patch+position)
- Encoder部分(右图)
- MLP head
Embedding部分(Linear Projection of Flattened Patches)
标准的Transformer所接收的输入是一个1D的序列,然而图像数据的维度是H*D*C。因此,我们要对输入的图片数据进行转变。
如图,首先将输入的图片进行拆分,对于Vit-B/16,输入图片尺寸为224*224,取每个patch的大小为16*16*3,因此划分后得到 (224/16)*(224/16) = 196 个patch。然后想通过线性映射,把每个patch拉直,变成一个1D的sequence,这样就可以把每个patch看成一个单词,维度是1*768。
具体实现中,类似CNN的卷积操作,输入维度是224*224*3,卷积核维度为211*211*768,则输出的维度为14*14*768,将W,H拉直,则变为196*768。
在一个图片分成patch后经过线性层处理之后得到很多token,在这些token的前面加一个class_token,这部分参考了Bert中的做法,这个token的维度和其他的token一样,也是768维,将其和后面的token进行拼接,得到的最终结果:1*768+196*768=197*768。
而且这个class_token是可学习的,它专门用来分类。
Transformer中由于注意力机制计算中是每两个patch之间进行计算关系,这样两两计算得到的结果不包含位置信息,然而图片patch的位置信息也是很重要的,因此解决方式就是在每个token上加上了位置编码。这个和上一个论文的做法相同,位置编码是直接加到图片信息上的,因此维度也是 197*768。
Encoder部分
这部分可以参考我之前的博客:
https://mp.csdn.net/mp_blog/creation/editor/128945199
总结一下,也就是这个公式:
Multi-Head Attention
这部分可以参考我之前的博客:
https://mp.csdn.net/mp_blog/creation/editor/128945199
以两个头为例:
然后对Q和K进行相乘再除以根号下d得到b11,b12,b21,b22。
把b11和b12,b21和b22分别进行拼接(把两个头进行合并)得到b1和b2
注意这里的两个头进行合并部分的Wo是可学习的参数。
MLP block
如图左侧所示,就是全连接+GELU激活函数+Dropout组成也非常简单,需要注意的是第一个全连接层会把输入节点个数翻4倍:
[197, 768] -> [197, 3072]
,
第二个全连接层会还原回原节点个数:
[197, 3072] -> [197, 768]
。
MLP head
上面通过Transformer Encoder后输出的shape和输入的shape是保持不变的,以ViT-B/16为例,输入的是[197, 768]输出的还是[197, 768]。注意,在Transformer Encoder后其实还有一个Layer Norm没有画出来,后面有我自己画的ViT的模型可以看到详细结构。这里我们只是需要分类的信息,所以我们只需要提取出[class]token生成的对应结果就行,即[197, 768]中抽取出[class]token对应的[1, 768]。接着我们通过MLP Head得到我们最终的分类结果。MLP Head原论文中说在训练ImageNet21K时是由Linear+tanh激活函数+Linear组成。但是迁移到ImageNet1K上或者你自己的数据上时,只用一个Linear即可。