An image is worth 16x16 words: transformers for image recognition at scale
Google Research, Brain Team
https://arxiv.org/abs/2010.11929
是啥
Transformer 已经是 NLP 领域中很火很香很重要的模型了,这篇文章把 Transformer 应用到 CV 任务上了,提出了 Vision Transformer (简称 ViT) 模型,它的效果超过了基于卷积神经网络的 SOTA 模型。
至此,NLP 和 CV 任务都可以使用相同模型架构来处理,也为后续多模态算法的发展提供了强有力的支持
架构
ViT 架构的总览图如下
运行流程如下
- 给定一张图,划分为多个 patch,上图为九宫格,也就是九个 patch
- 摊开变成了序列后输入到 linear Projection (全连接)获得 patch embedding (768 维)
- 图片每个 patch 有顺序关系,因此需要在上一步之后加上 position embedding (768 维)
- 此外还会加入多一个 position 固定为 0 的 class embedding (768 维)
- 随后输入到标准的 Transformer Encoder 中,如右上所示
- 最后 Transformer Encoder 在 position 为 0 的输出会再经过一层 MLP Head,得到图片最终的分类结果
为什么加入 class embedding?
因为 transformer Encoder 的多头注意力机制能够让 patch 之间相互交互,从而可以认为 position 为 0 的输出是结合了多个 patch embedding 后所获得的结果
代码中 cls token 的实现
# ref link: https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py#L102
class ViT(nn.Module):
def __init__(self):
....
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
....
def forward(self, img):
....
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
....
能否像 CNN 那样微调获得更大输入分辨率的 ViT 呢?
存在局限性,在保持 patch 尺寸不变情况下,大分辨率会导致 patch 数量变多,从而让原来的 position embedding 失去意义,作者提出了折衷方案,对 position embedding 做插值来扩充长度,但会导致精度下降
ViT 对比 CNN 不同之处?
ViT 没有很多预定义的先验知识,也就是论文中提到的归纳偏置,这会导致 ViT 在小数据集上的效果不如 CNN 好,下个环节的实验部分也会讲到。
训练&效果评估
ViT 在 ImageNet 分类任务上结果如何?
由上图所示,ViT 在多个数据集上的精度都好于 ResNet
ViT 需要多少数据量才能训练的好?
- 左图实验告诉我们 Transformer 至少要 ImageNet-21K 数据量级来训练才能超过 ResNet
- 右图实验表明只有基于大数据量训练出来的 ViT 才能够在 few-shot 任务上超过 ResNet
ViT 比 CNN 训练起来便宜
如下图,作者在 JFT 和 ImageNet 数据集上训练了多次的模型,每类模型取五次平均获得最终精度,如下图所示
- 低计算量下,Hybrid 模型(CNN+ViT)精度最高
- 在中高计算量下, ViT 精度比 ResNet 高
深入 VIT 内部
左图: 对 Linear Projection 得到的 28 个 Patch Embedding 可视化,可以看到 ViT 学习到了图像低维度的特征表示
中图: 把 9 个 Position Embedding 向量两两做相关性计算,得到相关矩阵后做可视化,可以看到他们是有学习到位置信息的
右图: Attention 模块通过权重来汇集不同 Patches 之间的特征,类似 CNN 中的感受野,可以看到越深层的 Attention 模块,他的权重之和越大,表明它提取了更多 Patches 的特征,也表明具备了更大的感受野
自监督
文中还对 ViT 自监督学习也做了探索,它借鉴了 Bert 中掩码的形式在 ImageNet 上训练分类模型,简单来说就是通过遮住图片部分像素,然后让 Decoder 预测遮住的像素值。
参考链接
https://github.com/lucidrains/vit-pytorch