AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE(ICLR2021)
文章目录
实验部分详见原文,文章为原文翻译,如有错误请参照原文
摘要
- 挑战:
- 虽然Tansformer架构再NLP任务上已经成为实际标准,但是它对CV的应用仍存在限制
- 在视觉中,注意力结合着卷积网络一起应用,或者替换卷积网络的某一部分,同时保持他们的整体结构不变
- 方法:
- 我们证明了这种对CNN的依赖是不必要的,一种完全应用Tansformer直接将图片patches序列化可以在图形分类任务中性能更好
- 贡献:
- 当大量数据的预训练并转移到中小型网络(数据集)的识别基准中时(ImageNet, CIFAR-100, VTAB, etc.),ViT对比SOTA卷积网络获得优异的结果,同时需要更少的计算资源训练
- 预训练模型,微调code:https://github.com/google-research/vision_transformer
1.引言
1st Para: 背景
- 基于自注意力的架构的Transformer(Vaswani et al., 2017)已经变成NLP的代表性模型
- 这项领域研究是在大的语料库上预训练,然后在小的特定任务的数据集上进行微调
- 由于Transformer高效的计算效率与可拓展性,可以在未见过的size上训练超过100B参数的模型
- 随着模型与数据的发展,仍然没有饱和性能的迹象
2st Para: 挑战
- 然而,在CV中,卷积架构仍保持统治地位
- 受NLP的成功所鼓舞,多项研究在试图整合自注意力的类CNN架构,有一些完全代替了卷积
- 在后来的模型中,虽然理论有效,但是由于使用了特定注意力模式仍没有在现代硬件加速器(如GPU)上有效的拓展
- 因此,在大规模的图像识别中,经典的ResNet架构仍保持SOTA
3st Para: 方法
- 受Transformer在NLP领域扩展的启发,我们进行了实验将标准Transformer做最小的修改直接用在图片上
- 为了这么做,我们将图片分成多个patches并将patches的线性嵌入序列化作为Transformer的输入
- 图像patches相当于NLP应用的tokens
- 我们用有监督方式在图像分类任务上训练模型
4st Para: 存在的问题
- 在中等规模数据集(如ImageNet)上不用强正则化训练时,这些模型准确率相比于RexNets低了几个百分点
- 这中看似令人沮丧的结果可能是预料之中的,Transformers相比于CNNs缺乏归纳偏好(由网络自身的结构、参数设置等因素赋予网络特有的能力),例如平移等变性和局部性(平移等变性指图像中物体的位置变化不影响CNN的识别能力,局部性指CNN更关注图像中的局部特征而非整体图像),因此在训练不够充足的数据时没有很好的泛化性
5st Para: 贡献
- 然而,如果模型在更大的数据集(14M-300M图像)上训练,图像就会发生变化
- 我们发现大规模训练相比于归纳偏好更加重要
- 我们的ViT在充足规模的预训练与转移到数据点较少的任务上获得了优异的效果
- 在ImageNet-21K与JFT-300M的预训练中,ViT在多种图像识别数据集中超越了SOTA
- 尤其是,最好的模型在ImageNet达到了88.55%,ImageNet-ReaL达到90.72%,CIFAR-100达到94.55%,77.63% on the VTAB suite of 19 tasks
2.相关工作
3.方法
- 在模型设计中,我们尽可能接近原始Transformer
- 这种有意的简单的设置的有点是可伸缩性NLP的Transformer架构与有效的实现,可以做到开箱即用
3.1.VISION TRANSFORMER (VIT)
Figure 1:Model overview. We split an image into fixed-size patches, linearly embed each of them, add position embeddings, and feed the resulting sequence of vectors to a standard Transformer encoder. In order to perform classification, we use the standard approach of adding an extra learnable “classification token” to the sequence. The illustration of the Transformer encoder was inspired by Vaswani et al. (2017).
1st Para: 方法描述
- 模型概括在Figure 1.
- 标准Transformer将token embedding的 1D序列作为输入
- 为了处理 2D图像,我们将图像 x ∈ R H × W × C \mathbf{x}\in\mathbb{R}^{H\times W\times C} x∈RH×W×C reshape为展平的 2D patches的序列 x p ∈ r N × ( P 2 ⋅ C ) \mathbf{x}_p\in\mathbb{r}^{N\times(P^2\cdot C)} xp∈rN×(P2⋅C) ,其中 ( H , W ) (H,W) (H,W) 是原始图像的分辨率, C C C 为通道数, ( P , P ) (P,P) (P,P)