VIT- Transformers For Image Recognition At Scale


vision transformer讲解非常清晰的一篇文章

ViT原理分析

这个工作本着尽可能少修改的原则,将原版的Transformer开箱即用地迁移到分类任务上面。并且作者认为没有必要总是依赖于CNN,只用Transformer也能够在分类任务中表现很好,尤其是在使用大规模训练集的时候。同时,在大规模数据集上预训练好的模型,在迁移到中等数据集或小数据集的分类任务上以后,也能取得比CNN更优的性能。

下图是原论文中给出的关于Vision Transformer(ViT)的模型框架。简单而言,模型由三个模块组成:

  • Linear Projection of Flattened Patches(Embedding层)
  • Transformer Encoder(图右侧有给出更加详细的结构)
  • MLP Head(最终用于分类的层结构)

在这里插入图片描述

Embedding层

对于标准的Transformer模块,要求输入的是token(向量)序列,即二维矩阵 [num_token, token_dim] ,如下图,token0-9对应的都是向量,以ViT-B/16为例,每个token向量长度为768。
在这里插入图片描述
对于图像数据而言,其数据格式为[H, W, C]是三维矩阵明显不是Transformer想要的。所以需要先通过一个Embedding层来对数据做个变换。

如下图所示,首先将一张图片按给定大小分成一堆Patches。以ViT-B/16为例,将输入图片(224x224)按照 16x16 大小的Patch进行划分,划分后会得到 ( 224 / 16 )2 = 196个Patches。接着通过线性映射将每个Patch映射到一维向量中,以ViT-B/16为例,每个Patche数据shape为[16, 16, 3]通过映射得到一个长度为768的向量(后面都直接称为token)。[224, 224, 3] -> [196, 16, 16, 3] -> [196, 768]
在这里插入图片描述

在代码实现中,直接通过一个卷积层来实现。 以ViT-B/16为例,直接使用一个卷积核大小为16x16,步距为16,卷积核个数为768的卷积来实现。 通过卷积[224, 224, 3] -> [14, 14, 768],然后把H以及W两个维度展平即可[14, 14, 768] -> [196, 768],此时正好变成了一个二维矩阵,正是Transformer想要的.

在输入Transformer Encoder之前注意需要加上[class] token以及Position Embedding。 在原论文中,作者说参考BERT,在刚刚得到的一堆tokens中插入一个专门用于分类的[class]token,这个[class]token是一个可训练的参数 ,数据格式和其他token一样都是一个向量,以ViT-B/16为例,就是一个长度为768的向量,与之前从图片中生成的tokens拼接在一起,例如:对于cat这个图片来说,([1, 768], [196, 768])-> [197, 768] 。然后关于Position Embedding就是之前Transformer中讲到的Positional Encoding,这里的Position Embedding采用的是一个可训练的参数(1D Pos. Emb.) ,是直接叠加在tokens上的(add),所以shape要一样。以ViT-B/16为例,刚刚拼接[class]token后shape是[197, 768],那么这里的Position Embedding的shape也是[197, 768]。

Transformer Encoder层

这里使用的Encoder与nlp的Transformer结构一致,这里不再赘述,如果不了解,可以看之前的博客。

MLP Head层

上面通过Transformer Encoder后输出的shape和输入的shape是保持不变的,以ViT-B/16为例,输入的是[197, 768]输出的还是[197, 768]。这里我们只是需要分类的信息,所以我们只需要提取出[class]token生成的对应结果就行,即[197, 768]中抽取出[class]token对应的[1, 768]。最后,通过一个Linear 全连接层得到类别信息即可
在这里插入图片描述
下面是从其他博客中拷过来的图片,非常详细地讲解了该网络结构
在这里插入图片描述

Hybrid混合模型

混合模型是指将传统CNN特征提取和Transformer进行结合 。论文中使用Resnet50作为特征提取器(注意,这里的resnet与原始resnet结构有一些不同,但大致一样),通过R50 Backbone进行特征提取后,得到的特征矩阵shape是[14, 14, 1024],接着再输入Patch Embedding层,将特征维度调整为transformer需要的维度。后面的过程就与前面一致了。
在这里插入图片描述

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值