Vision Transformer 原理和效果速览

An image is worth 16x16 words: transformers for image recognition at scale
Google Research, Brain Team
https://arxiv.org/abs/2010.11929

是啥


Transformer 已经是 NLP 领域中很火很香很重要的模型了,这篇文章把 Transformer 应用到 CV 任务上了,提出了 Vision Transformer (简称 ViT) 模型,它的效果超过了基于卷积神经网络的 SOTA 模型。

至此,NLP 和 CV 任务都可以使用相同模型架构来处理,也为后续多模态算法的发展提供了强有力的支持

架构


ViT 架构的总览图如下
在这里插入图片描述

运行流程如下

  1. 给定一张图,划分为多个 patch,上图为九宫格,也就是九个 patch
  2. 摊开变成了序列后输入到 linear Projection (全连接)获得 patch embedding (768 维)
  3. 图片每个 patch 有顺序关系,因此需要在上一步之后加上 position embedding (768 维)
  4. 此外还会加入多一个 position 固定为 0 的 class embedding (768 维)
  5. 随后输入到标准的 Transformer Encoder 中,如右上所示
  6. 最后 Transformer Encoder 在 position 为 0 的输出会再经过一层 MLP Head,得到图片最终的分类结果

为什么加入 class embedding?

因为 transformer Encoder 的多头注意力机制能够让 patch 之间相互交互,从而可以认为 position 为 0 的输出是结合了多个 patch embedding 后所获得的结果

代码中 cls token 的实现

# ref link: https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py#L102

class ViT(nn.Module):
	def __init__(self):
			....
      self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
      ....
  
  def forward(self, img):
      ....
      cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
      x = torch.cat((cls_tokens, x), dim=1)
      ....

能否像 CNN 那样微调获得更大输入分辨率的 ViT 呢?

存在局限性,在保持 patch 尺寸不变情况下,大分辨率会导致 patch 数量变多,从而让原来的 position embedding 失去意义,作者提出了折衷方案,对 position embedding 做插值来扩充长度,但会导致精度下降

ViT 对比 CNN 不同之处?

ViT 没有很多预定义的先验知识,也就是论文中提到的归纳偏置,这会导致 ViT 在小数据集上的效果不如 CNN 好,下个环节的实验部分也会讲到。

训练&效果评估


ViT 在 ImageNet 分类任务上结果如何?

在这里插入图片描述

由上图所示,ViT 在多个数据集上的精度都好于 ResNet

ViT 需要多少数据量才能训练的好?

  • 左图实验告诉我们 Transformer 至少要 ImageNet-21K 数据量级来训练才能超过 ResNet
  • 右图实验表明只有基于大数据量训练出来的 ViT 才能够在 few-shot 任务上超过 ResNet

在这里插入图片描述

ViT 比 CNN 训练起来便宜

如下图,作者在 JFT 和 ImageNet 数据集上训练了多次的模型,每类模型取五次平均获得最终精度,如下图所示

  • 低计算量下,Hybrid 模型(CNN+ViT)精度最高
  • 在中高计算量下, ViT 精度比 ResNet 高

在这里插入图片描述

深入 VIT 内部

在这里插入图片描述

左图: 对 Linear Projection 得到的 28 个 Patch Embedding 可视化,可以看到 ViT 学习到了图像低维度的特征表示

中图: 把 9 个 Position Embedding 向量两两做相关性计算,得到相关矩阵后做可视化,可以看到他们是有学习到位置信息的

右图: Attention 模块通过权重来汇集不同 Patches 之间的特征,类似 CNN 中的感受野,可以看到越深层的 Attention 模块,他的权重之和越大,表明它提取了更多 Patches 的特征,也表明具备了更大的感受野

自监督

文中还对 ViT 自监督学习也做了探索,它借鉴了 Bert 中掩码的形式在 ImageNet 上训练分类模型,简单来说就是通过遮住图片部分像素,然后让 Decoder 预测遮住的像素值。

参考链接


https://github.com/lucidrains/vit-pytorch

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值