Vision Transformer (ViT) 是一个基于 Transformer 架构的计算机视觉模型,它将 Transformer 直接应用于图像分类任务,跳脱了传统卷积神经网络(CNN)的框架
ViT的核心思想是将图像转化为一维的序列数据,并将其输入到Transformer中进行处理,最终进行分类。以下是ViT模型的详细架构解析
目录
1.1 Linear Projection of Flattened Patches 模块
1 模型架构
Vision Transformer (ViT) 受到了 Transformer 在自然语言处理(NLP)领域的成功启发,尝试将 Transformer 架构直接应用于计算机视觉任务,特别是图像分类任务。与传统的卷积神经网络(CNN)不同,ViT 在设计时尽量保留了 Transformer 的原始结构,并进行最小的修改,目的是能够开箱即用地利用 Transformer 在 NLP 领域的强大体系结构和高效实现
1.1 Linear Projection of Flattened Patches 模块
1.1.1 图像分块
在图像处理任务中,ViT 的一个关键创新是如何将二维图像转化为适合 Transformer 处理的序列数据。标准的 Transformer 架构适合处理序列数据,而图像是二维数据,因此需要将图像转换为一维序列。为此,ViT将输入图像切分为多个小块(patches)
假设输入图像的分辨率为224×224像素,ViT 将该图像切分成多个16×16像素的小块。每个小块(patch)被认为是一个输入元素,类似于 NLP 中的单词或子词。对于一个224×224的图像,切分后的图像块数量为:
即总共有196个小块(每个块的尺寸为16×16像素)
1.1.2 扁平化和线性变换
每个图像块被扁平化为一维向量,其维度为16×16×3=768,其中16×16是每个图像块的像素数,3是 RGB 通道数。因此,每个小块的嵌入向量长度为768。然后,这些扁平化的小块通过一个线性嵌入层被映射到一个更高维度的空间,以便传入Transformer模型
通过将图像切分成小块,ViT 成功地将图像转化为了一个序列数据,符合 Transformer 的输入要求。这样做的好处之一是能够控制输入序列的长度,避免原始图像的高分辨率带来的计算复杂度过高的问题。对于每个图像块,维度是768(16×16×3),而整个序列的长度是196(14×14),因此最终输入Transformer的序列长度为196
这对于大多数计算机来说是一个可以接受的长度,并且这种序列长度可以在合理的计算资源范围内进行训练。这是ViT能够在资源有限的情况下仍然表现出色的原因之一
1.1.3 添加位置编码
在 ViT 中,模型处理图像时,采用了自注意力机制来捕获图像块之间的关系。由于 Transformer 的注意力机制要求输入的元素之间进行两两计算,因此输入顺序并不会影响注意力矩阵的计算结果。然而,图像本身是一个具有空间结构的二维数据,图像块之间的相对位置是非常重要的,因此不能随意地组合图像块
Transformer模型并没有内置位置感知能力,因此必须为每个图像块添加位置编码,以保持图像中每个块的空间位置信息
与 BERT 中的做法类似,ViT 也在输入序列的最前面添加了一个 CLS词元(Classification Token),这个词元的作用是作为整个图像的全局表示,并最终用于分类
CLS词元的作用
- 学习的词元:CLS词元是一个可以学习的特殊词元,它代表了整个输入图像的全局信息。通过这种方式,Transformer模型可以在训练过程中通过CLS词元提取出图像的全局特征
- 编码器作用:由于 ViT 只使用了 Transformer 的编码器,没有解码器部分,所以 CLS 词元在这里充当了“解码器输出”的角色。它类似于标准 Transformer 解码器中的右移操作,用于指示模型开始进行图像的全局特征提取
最终输入的矩阵维度
输入 Transformer 模型的矩阵维度是(196 + 1) × 768,即196个图像块加上1个CLS词元(总共197个元素)。通过这种方式,CLS词元不仅起到了一个特殊标记的作用,而且能在训练过程中提取整个图像的全局特征
1.2 Transformer 编码器
ViT的结构与传统的 Transformer 编码器类似,细节可参考之前笔记文章,但有一些针对图像处理的特殊调整,每一层操作如下:
- 标准化层:首先对输入进行标准化,以提高训练的稳定性
- 多头注意力层(Multi-Head Attention):在这一层中,所有输入元素之间(图像块之间)的相似度都会被计算出来,并利用注意力机制加权合并信息
- 残差连接:为了避免深层网络训练中的梯度消失问题,Transformer 采用了残差连接,使得每一层的输出都与输入相加,帮助梯度更有效地传播
- 标准化层和 MLP 层:每一层的输出还会经过一个标准化层,之后是一个前馈神经网络(MLP),用于进一步处理特征
- 再次的残差连接:每一层的输出还会通过一个残差连接,以便保留初始输入的信息
这一系列的操作保证了模型能够有效地捕获图像块之间的复杂依赖关系
1. 多头注意力机制
ViT 使用多头注意力机制(Multi-Head Attention)来捕获图像块之间的上下文关系。每个注意力头都在不同的子空间中学习图像块之间的关系,因此能够捕获不同层次的特征。多头注意力的工作方式如下:
- 在网络的初期,多个注意力头之间的关注范围较近,这意味着模型在学习时会更多关注局部区域的信息。
- 随着网络的深度增加,注意力头之间的距离逐渐增大,这说明模型开始关注图像中的全局信息,并能够整合更多远距离像素之间的关系。
这种注意力机制的工作原理类似于卷积神经网络中的感受野,通过对图像不同区域之间的注意力加权,模型可以逐步整合整个图像的信息
1.3 MLP Head 分类头
经过若干层 Transformer 编码器的处理后,得到的类嵌入向量包含了整个输入图像的信息,并被送入分类头(通常是一个全连接层),用于最终的图像分类任务
MLP 由多个全连接层和激活函数组成,经过一系列线性变换和非线性激活后,生成用于分类的输出向量,最后通过 Softmax 函数转换为类别概率分布,完成最终的分类任务
在 ViT 中,通常使用 [CLS] 标记作为图像的表示。通过给定一个额外的 CLS 标记(类似于 BERT中的 CLS token),将其与其他图像块的表示一同输入到 Transformer 中,最终 CLS token 的输出将作为整个图像的全局特征来进行分类
- 在Transformer模型的输出阶段,CLS词元会包含整个图像的全局信息
- 使用切片操作提取出第一个 CLS 词元,其维度为1×768。这就是我们用于分类的图像表示
2 ViT 预训练
2.1 模型性能与数据集大小的关系
随着 Transformer 模型的规模逐渐扩大,尤其是在计算机视觉任务中应用 Transformer 时,研究者并没有看到性能趋于饱和的迹象。这个现象非常有趣,因为在许多机器学习任务中,扩大模型或增加数据集并不总是能带来持续的性能提升,甚至可能导致过拟合问题。然而,Transformer模型,尤其是 ViT,在处理更大数据集时似乎能够继续提升性能
通常,当我们扩大模型(例如增加模型的层数、参数量或节点数)时,可能会面临过拟合的风险,尤其是在数据量较小的情况下。当模型变得非常复杂时,它可能会“记住”训练集中的噪声和细节,而不是真正的规律,导致模型在训练集上表现得很好,但在测试集上却表现不佳
然而,对于 ViT 模型,这种现象似乎并不明显。ViT 模型在处理大规模数据集时,仍然展现出显著的性能提升,并没有出现传统模型中容易见到的饱和或过拟合现象
2.2 ViT 训练数据集
ViT 的训练依赖于不同规模的数据集,主要包括以下几个数据集:
- ImageNet-1k:包含1000个类别和约130万张图像。这个数据集广泛用于图像分类任务,是一个中等规模的标准数据集
- ImageNet-21k:包含21000个类别和1400万张图像。这个数据集在规模上比 ImageNet-1k 要大得多,适用于更高效的预训练
- JFT-300M:这是谷歌自己收集的一个数据集,包含18000个类别和3亿多张图像。它是一个超大规模的数据集,适用于进行大规模的深度学习预训练
2.3 ViT 数据集规模实验
为了研究 ViT 模型对不同数据集规模的敏感性,原作者们进行了相关实验。论文中实验结果显示,数据集的大小对 ViT 的性能有显著影响,尤其是在较小的数据集上,ViT 的效果表现不如其他传统的卷积神经网络(ResNet)
2.3.1 实验结果
如下图所示,横轴表示不同大小的数据集,纵轴表示模型在 ImageNet 测试集上的准确率。图中的不同点代表了不同规模的 ViT 和 ResNet 模型,其中 “BiT” 代表不同大小的 ResNet 模型,其他不同颜色和大小的圆点则表示不同规模的 ViT 模型。
- 在 ImageNet-1k 上预训练时,ViT 的效果基本上不如 BiT(即 ResNet 模型)。这是因为 ViT 在训练时并没有卷积神经网络(CNN)那样的局部感知先验,而CNN通过卷积核能够更好地捕捉局部特征,因此在较小的数据集上表现更好
- 在 ImageNet-21k 上预训练时,ViT 的性能明显提升,并且接近 BiT 的性能。这表明随着数据集规模的增大,ViT 能够通过更多的数据来学习图像的全局信息,弥补其缺乏局部先验的不足。
- 在 JFT-300M 数据集上预训练时,ViT 的表现明显超过了 BiT。JFT 数据集的巨大规模使得 ViT 能够充分挖掘图像中的复杂特征,弥补了传统CNN在大数据集上的性能瓶颈
2.3.2 ViT的优势与数据集规模的关系
数据集的规模对于Transformer的重要性:
- 小数据集:在小型数据集(如 ImageNet-1k)上,ViT 的表现通常不如具有局部先验的 CNN (ResNet),这是因为 ViT 在小数据集上无法充分发挥其全局建模的优势
- 中等规模的数据集:在中等规模的数据集(如 ImageNet-21k)上,ViT 的表现接近 ResNet,这表明随着数据量的增加,Transformer 开始展现出与CNN类似的效果
- 大数据集:在大规模数据集(如 JFT-300M)上,ViT的全局建模能力得到了充分的发挥,能够比传统 CNN 模型更好地捕捉复杂的图像特征,因此其性能超越了 ResNet 等经典 CNN 模型
从上述实验结果可以看出,ViT 的优势随着训练数据集的规模增大而逐步展现,尤其在超大规模的数据集(如 JFT-300M)上,ViT 能够发挥其最大的优势,超过传统的 CNN模 型(如ResNet)。这种现象表明,在大规模数据集上训练 ViT 时,模型能够充分利用其自注意力机制的全局建模能力,从而获得更好的泛化性能
实验结果提供了两个重要的信息:
-
数据集大小的建议:如果你的数据集规模较小(如 ImageNet-1k),那么使用卷积神经网络会更加合适,因为 CNN 已经具备了处理小数据集时所需的先验知识,能够在小数据集上表现得更好
-
数据集足够大时的优势:如果你的数据集非常大(如 ImageNet-21k 或 JFT-300M),Transformer 模型(如ViT)将能充分发挥其优势,且其性能可与最好的卷积神经网络(如ResNet)相媲美,甚至超越它们。ViT的可扩展性和训练效率相比CNN在大规模数据集上更具优势
2.4 ViT 与 CNN 对比与选择
当在中型数据集(如 ImageNet)上训练 ViT 时,如果不加上一些强约束,ViT 的效果通常会逊色于同等规模的 ResNet。但是,当在大型数据集(如 JFT-300M)上训练 ViT 时,即使没有强约束,ViT 也能够取得与 ResNet 相近,甚至超越它的性能。这一现象可以归结为Transformer 模型和卷积神经网络之间的一个核心区别:先验知识的差异
2.4.1 CNN 中的先验知识
CNN 利用了两种重要的先验知识,使得它在小数据集上训练时也能得到较好的结果:
-
局部特征性:CNN 通过滑动窗口的形式来提取图像的特征,即每个卷积核只关注图像中的局部区域。这意味着卷积神经网络天然假设图像中相邻的像素区域可能具有相似的特征(例如,边缘、纹理等)。这种局部特征性让 CNN 能在图像的每个区域上提取到局部特征,有助于它在图像分类、物体检测等任务中的表现
-
平移不变性:卷积操作本身具有平移不变性,即无论图像中的对象在空间上如何平移,卷积神经网络都会对其进行相同的特征提取。因此,卷积神经网络能够在图像中处理各种平移变换,保持一定的鲁棒性
这两个先验知识使得卷积神经网络在面对相对较小的数据集时,能够凭借其基础结构表现得非常好,因为它不需要从零开始学习这些基本的图像特性,已经有了强大的结构性假设
2.4.2 Transformer 缺乏先验知识
与卷积神经网络不同,Transformer模型并没有这些局部特征性和平移不变性的先验知识
Transformer 模型依赖自注意力机制来建模输入序列中各个元素之间的关系,它没有任何关于图像局部结构或空间位置信息的内建假设。因此,Transformer 在训练时必须依赖大量的数据来学习这些视觉世界的感知能力
在大规模数据集上,ViT 能够通过自注意力机制高效地捕捉图像的全局信息,因此具有比传统卷积神经网络更好的可扩展性和训练效率。这一点体现在如下两个方面:
-
自注意力机制:Transformer 通过自注意力机制,能够将图像中每个像素点的信息与其他所有像素点的信息进行交互,从而捕捉到整个图像的全局依赖关系。这对于大规模数据集来说非常重要,因为数据集中的图像往往包含了大量的空间信息和长程依赖
-
训练效率:Transformer 模型的参数和结构设计使得它在大规模数据集上进行训练时能够更高效地利用数据,因此 ViT 在较大的数据集上表现得尤为出色。
2.4.3 混合架构:结合 CNN 和 Transformer 的优势
既然 CNN 和 Transformer 模型各有优缺点,那么混合架构自然成为一个值得探索的方向。这样的混合架构可以充分利用 CNN 在局部特征提取和平移不变性方面的优势,同时也能够结合 Transformer 在全局依赖建模方面的强大能力
混合架构的设计:
-
卷积神经网络(CNN)作为特征提取器:可以使用一个较小的CNN(如ResNet)作为特征提取器,从原始图像中提取局部特征。CNN擅长捕捉图像中的低级特征(如边缘、纹理、颜色等),并且能够保留平移不变性的信息。
-
Transformer处理全局特征:将CNN提取的特征图作为输入,送入Transformer模型中。Transformer可以有效地建模图像块之间的全局依赖关系,捕捉物体的位置、大小、背景等高层次的语义信息
在后续的 Diffusion Policy 中,进一步详细分析此内容