图1 vit的架构图
从图1可以看出来,vit是将输入图片分割成一个一个patch,这里是patch的大小是16*16,然后将这些patch拉平并通过线性变换映射为固定长度的向量,然后会在这些向量的最前面加上一个用于分类的token(因为vit只是想将transformer用在图像分类上面,这个token叫做[CLS] token),然后每个patch还会再加上一个可以学习的位置编码,然后输入到 Transformer Encoder 中。最后,提取出 [CLS] token 对应的特征输入到分类器中分类。
上述是vit的一次前向的过程,接下来讲一下具体的实现。
假设输入的是一张224*224*3的图片,经过一次卷积(卷积核的大小是16*16,步长是16),经过这次卷积我们可以得到14*14*768的特征图(像素的个数不变,相当于把每个patch拉直在了一个像素点里面,所以是16*16*3=768),接下来我们把特征图拉直就得到196*768的矩阵,然后concat一个cls token,之后把位置编码直接与向量相加,在这里我们得到的是197*768的矩阵,然后就可以送到encode里面了。然后对得到的数据进行切片(因为对于transformer来说,数据的维度是不变的),我们将cls token切出来送到mlp里面得到最后的结果
mlp一般来说就是一个全连接层,根据数据集大小可以自己设计
关于位置编码,vit使用的位置编码是学习得到的,这里补充一下比较一般的位置编码
pos是这个词在句子中的顺序,d是总维度数,i是维度是一半向下取整,比如第0维的时候i=0,第2维i=1,偶数位和奇数位指的是在维度中的位置是奇还是偶。
那这里还有一个问题就是加上位置编码之后,数据的信息不会被破坏吗?其实是不会的,因为有充足的训练数据,使得模型完全可以理解这种复合的信息。同样的,模型一样可以理解同一个词在不同位置的意思,比如Can you can a can as a cannner can a can。