ViT理解

最近在阅读transformer的第一篇文章ViT,也作个笔记,供学习使用。希望阅读者有CNN的基础,如YOLO,MobileNets,ResNet等,不然读起来可能比较吃力。当然笔记也尽可能解释清楚。

核心参考资料

① B站深度之眼:CV transformer
② 代码网址:https://github.com/lucidrains/vit-pytorch
③ 原文:AN IMAGE IS WORTH 16X16 WORDS:
TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE

摘要

1、Transfomer在NLP领域已经成为一个标准
2、在CV中,Attention只是作为一个补充或替代某一个结构
3、本文使用纯Transfomer在图像分类任务中取得了一个非常好的成绩
4、在足够在的数据集上训练后,ViT(Vision Transformer)取得了和CNN一样的SOTA(最佳)结果

思想来源

在这里插入图片描述

解释:
1、Transformer在NLP领域大放异彩,我们也想拿来CV领域中用
2、并且不想做太多的改变,直接能用最好
3、具体来说就是把一张图片拆分成很多小片,变成一个序列。这样就能作为Transformer合格的输入(把图片的小片看成NLP中的tokens(words))
注:token可以看做是一个最小的处理单元

结构设计

在这里插入图片描述
对于CV而言,该模型比较陌生的Transformer Encoder和下端的输入端适配,而MLP:多层感知机,具体而言就是一个FC层
Transformer Encoder如右端所示:Norm其实就是Layer Norm,Multi-head Attention为重点。也就意味着Transformer Encoder由L个该模型组成,其中只有Multi-head Attention不熟悉
在输入端:不知道的是图像怎么输入?怎么切分重排?
所以对于ViT陌生的只有两个,Multi-head Attention和输入端怎么适配?下面就开始讲解这两部分

1 Multi-head Attention

1.1 背景介绍(seq2seq)

以机器翻译为例,如下图所示,在建模的时候,我们输入中文狗,经过编译生成了一个像图中一个物体的实例,然后这样的物体在英文里又是怎么写的呢?这样再去解码,解出来是dog。也就是说在机器翻译的时候,两种不同的语言表达的都是同样的意思,而我们在输入一句的时候要经过编码(Encoder),生成一个意思,而这个意思在另外一种语言中是怎么表达的呢?这就需要解码(Decoder)把它解出来。
在这里插入图片描述

以RNN为例,x1,x2,x3,x4就是输入(input),进来之后经过编码(Encoder),编码之后就是中间状态向量c,也就是意思,从c之后就是解码(decoder),然后就输出
在这里插入图片描述
但是事实上,一句话并不需要所有都输入完在翻译,也就是一句话并不需要所有的单词,而是只需要注意到几个单词就能翻译。也就是attention背后的力量。如下图所示:attention具有并行计算和全局视野的优点。这里只需要知道attention的优点即可。

在这里插入图片描述

该图片引自:作者:钟漂亮的太阳

1.2 Self Attention

注:该节、1.3节和3节所有的图片均来自:B站深度之眼CV transformer

Self Attention的计算,实际上是计算相似度。Q(Query):查询,K(Key):关键词,V(Value):价值。公式理解:比如我在CSDN搜索CV,然后我输入的词也就是查询词CV,点击搜索,数据库开始查找有关CV的关键词,此时数据库里的CV和搜索的CV进行匹配,也就是相似度的计算即QK^T,计算之后有个页面返回出来,这个返回出来的页面也就是价值V(Value)。那么√d,直观理解就是缩放,把QK的值缩放到一个合理的水平,softmax其实就是归一化,值在0~1之间,于是就相当于给V一个权重。所以Attention的值越大越好。
在这里插入图片描述
那Q,K,V又是怎么来的呢?这里需要定义三个矩阵,分别为Wq,Wk,Wv。输入x1,x2,x3分别去乘它们,得到的结果便是Q,K,V
在这里插入图片描述
当宏观理解了这个公式,具体推到可查看:作者:太阳花的小绿豆
相信你能很轻易的看懂公式推导。

1.3 Multi-Head Attention

所谓的Multi-Head Attention,其实就是多个Wq,Wk,Wv重复上述的操作,将结果concat在一起。之后再Wo输出。注意:这里Wo是做一个维度的转换,加入有四个self Attention Layer,那么出来concat之后,维度就会变成原来的四倍,但是实际上我们并不需要这么多维度,于是就进行一个维度转换,把维度转为自己需要的。
在这里插入图片描述
Attention的优点:1、提供了多种可能性。比如下图的Law在上面就显得很全局,而下图中的Law便显得很局部;2、注意力是稀疏的。如一个词可能只和几个词有关,虽然一个句子就算很长,它也之和几个单词有关。所以,注意力矩阵可能是一堆0,只有几个1。因此,不希望注意力矩阵维度很高,因为如果很高的话,他的内部会有很多0。所以我们可以用少量多个,这样会好一点。所以我们选择用多头的注意力。
在这里插入图片描述

1.4 Transformer Encoder

讲到这里,我们再来看Transformer Encoder。首先,Embedded Patches先不管,进来之后就是Norm(和BN差不多),然后进入Multi-Head Attention,之后和跳跃连接相加之后又进入Norm,MLP最后和跳跃连接相加输出。就这样一直循环L次。
在这里插入图片描述

2 输入端适配

2.1 图像切分重排

在这里插入图片描述

直接把图片切分,然后编号输入网络即可。这里就衍生出两个问题:
1、既然是切分重排再输入,那么0是什么?它又不从图像中来
2、Position Embedding又是什么?
回答:1、为什么会有Patch 0?因为它需要一个整合信息的向量。这里的Transformer只有Encoder没有Decoder。所以它这里只有一个编码,即一个求表式的一个过程,也即一个图像应该用什么样的表来表示的过程。但是它又没有信息整合输出的能力,所以说就需要一个东西来表示。不懂?举个例子:就比如说,在NLP中,把这个几个切分的图片送入网络经过Encoder之后,出来的是什么呢?就拿1号小片来说,编码输出来之后的结果就表示1号小片在原图像中的意思是啥。同理,2号小片编码输出的结果就表示2号小片在原图像中的意思是啥,一直到9号。所以不能拿每一个孤立的小片做分类,也就是说如果我们拿1号小片来做分类,这显然不对。因为1号小片仅含有一部分物体。那我们能不能把所有的小片编码出来之后加在一起求平均来表示?这也不对,这样的话我们就把物体和背景的信息都含有了,这样变会混淆物体和背景。那么我们能不能把所有的Concat(维度叠加,注意不是相加add)在一起,然后再过MLP?concat的想法很好,但是注意这里编码的维度都很长,而且concat的维度是和切分的片数有关,假如这里切分不是33而是5252,那concat在一起不能显存爆炸?那么既然不能只求一个也不能求平均,那应该怎么办呢?那我再加一个维度(path 0),这个维度0就代表什么类别,然后也是经过Muti-Head Attention计算和剩下9个path的关系,总之就是和剩下9个path的处理一致。
所以简单来说,如果只有原始输出的9个向量,用哪个向量来分类都不行,缺乏整体物体信息。而全计算,计算量又很大。所以需要添加一个可学习的向量patch 0来整合信息。
之后再经过Encoder之后输出也是10个,但是只取了第0个,放到MLP head里面

2.2 位置编码

在这里插入图片描述

图像切分重排后失去了位置信息,也就是说图片空间上的信息完全没有了。那应该怎么办呢?并且Transformer的内部运算与空间信息无关的。我们在算QKV的时候,其实和每一个单词,也就是每一个token都要算一遍QK。也就是说Q要和每个K都要做一个点积,无论是K几,这样遍意味着和位置无关。即和空间信息无关。
但是如果没有位置信息做图像分类,那怎么能行呢?所以需要把位置信息编码重新传进网络。也就是需要Position Emdedding的过程。ViT使用了一个可学习的vector来编码,编码vector和patch vecter直接相加组成输入。可以理解为0123~9为位置编码,后面粉色的为patch vecter。它们的大小维度是一样的。
那么为什么前面每个patch的东西不能加,这里position就可以直接相加,而不是用concat?是因为相加是concat的一种特例,w(i+p)=wi+wp(i表示Input的特征,p表示位置特征),而concat是采用(w1+w2)(i+p)=w(1i)+w(2p)(如果w1=w2),那其实就是相加。也就是说concat有一些情况就是和相加是一样的。其实这里是一个人为的先验,因为矩阵相乘就是一个空间上的转换,就认为位置的转换和input特征的转换,人认为应该用同样的一个w矩阵来建模。另外不能同concat的另外原因是因为concat之后的维度会很大,不利于计算。

3 数据流

在这里插入图片描述
b:batch_size,c:channel(rgb图为3,灰度图为1),hw:宽高。N:切分的块数,这里为88;P2C:假设切好的一个小块像素为3232(P2),c为3,然后将其拉平变为32323=3072。然后进入Linear Projection of Flattened Patches,Flattened其实就是拉平,也就解释了为什么是3072,而Linear Projection就是把维度降低,因为3072的维度太高了,于是过了一个FC就变成了(b,64,1024),注意:这里是64个小块,不是9,9只是一个示意图。然后再加上patch0,path0的维度和降维后的patch一样为1024。之后进入Transformer Encoder,看右图,此时变成了(b,N+1,D)=(b,65,1024),之后经过multi-head attention和MLP维度都是一样的,为(b,65,1024)。至于是怎么做的,后面再讲解。
经过L次以后到输出,看左边图。输出的时候只拿0号位置放到MLP里面来进行分类,后面的64个丢弃。所以这里维度就是(b,1,D)=(b,1,1024),MLP其实就是一个FC,输出(b,1,num_class)

后面章节重点看代码部分

4 训练方法

大规模使用Pre-Train:先在大数据集上预训练,然后到小数据集上Fine Tune。迁移过去之后,需要把原本的MLP Head换掉,换成对应类别数的FC层。处理不同尺寸输入的时候需要对Position Encoding的结果进行插值

5 结果

在这里插入图片描述
1、ViT的性能需要大量的数据,不然ViT Large的性能无法充分发挥。
在这里插入图片描述
2、Attention距离和网络层数关系
Attention的距离可以等价为Conv中的感受野大小。层数越深,Attention跨越的距离越远。但是在最底层,有的head也可以覆盖到很远的距离。这说明它们确实在负责global信息整合
在这里插入图片描述
3、分类的结果和分类的语义高度相关
在这里插入图片描述

6 代码分析

参考:https://github.com/lucidrains/vit-pytorch

找到ViT的入口函数,如下图所示:

在这里插入图片描述
其中各个参数的意思,作者也给出了解释,如下图所示:
在这里插入图片描述
为了更好的解释该函数,我们把它进行了中文批注并且放在图中解释:
在这里插入图片描述
在这里插入图片描述

至此,如果你认真阅读笔记,那么整个宏观流程我们已经掌握了,那么剩下的就是细节的问题,我们一一来分析。

6.1 输入端适配

6.1.1 图像切分重排

当我们输入一批完整的图片(b,c,h,w)是怎么切分重排的呢?按照下图所提供的步骤找到self.to_patch_embedding
在这里插入图片描述
此时,我们找到了图像切分重排的函数,那么应该怎么解释呢?以b=1为例,也就是以一张图片为例。随机创建(b,c,h,w)=(1,3,256,256)的图片,然后(b,c,h,w)=(b,c,(N×p1),(N×p2))。由于p1=p2,故(b,c,(N×p1),(N×p1))=(b,c,(N×p),(N×p))=(1,3,(8×32),(8×32))。图中 Rearrange('b c (h p1) (w p2) ,实际上就是这个意思。注意公式里的 h w实际上就是N。于是整句话就成了(1,3,(8,32),(8,32))->(1,(8,8),(32,32,3))。最后经过Linear变成(1,64,1024)。
在这里插入图片描述
此时,你可能还有疑问,Rearrange其实就是切分重排,那它到底是怎么运算的呢?该问题的解答可以参考官网的解释。

https://github.com/arogozhnikov/einops/blob/master/docs/1-einops-basics.ipynb

现在我们已经完成了蓝色框的讲解,接下来我们讲解patch0的构造
在这里插入图片描述

6.1.2 patch0构造

我们先回到刚才代码停留的页面,找到97,114,115行
在这里插入图片描述
我们首先创建一个可学习的向量(1,1,dim),根据上面的分析这里dim其实等于1024,其实就是和每一个小片的维度保持一致。下面来分析114,115行。repeat的目的其实就是我们输入的图片是按批次,也是batch处理的,所以batch有多大,我们就repeat几次。我们这里把batch取为了1,所以就repeat一次。这样再和6.1.1小结的输出进行concat就变成了(1,65,1024)
在这里插入图片描述
现在我们已经完成了蓝色框中的粉色框(也就是数字后面的小框)讲解,接下来我们讲解positional embedding
在这里插入图片描述

6.1.3 positional embedding

我们先回到刚才代码停留的页面,找到96,116行
在这里插入图片描述
先看第96行,这里先定义了一个pos_embedding为torch.randn(1, num_patches + 1, dim),这里为什么要+1呢?正如6.1.2小节所分析的,我们创建了patch 0,并且之后和patch进行concat,也就是维度叠加,现在维度已经比之前二队patch多了1。所以要+1。所以这句话的意思就是随机生成(1,65,1024)的矩阵/张量/向量,这样大小一致,就可以和patch进行add,也就是加和。这里我们在2.2 小节位置编码已经解释过。
现在来看第116行,这里有个问题,代码后面是(n+1)的切片来限制个数,然而第96行已经定义好了是num_patches + 1。所以这里完全没有必要载进行一次切片。所以我们这里可以改成

x += self.pos_emdbedding

注意:这里是+=,也就是和之前的patch进行add。至此这里维度还是(b,N+1,D)=(b,65,1024)=(1,65,1024)。这里第二个等号成立是因为我们拿一张图片进行分析。
现在我们分析完了蓝色框里的剩余部分,接下来分析Transformer Encoder。
在这里插入图片描述

6.2 Transformer Encoder

回到代码页面,我们找到Transformer的定义,可以看到代码封装了PreNorm,Attention,FeedForward三个类。并且从75,76行可以看出,Transformer Encoder被分为两个部分,以中间加号为界,下面的第一部分为Norm->Multi-Head Attention->跳跃连接。上面剩下的为第二部分。下面我们依次封装的三个类。
在这里插入图片描述

6.2.1 PreNorm

可以看出PreNorm输入的是一个维度和一个类。而它的forward要先执行norm,然后再执行fn,最后把所有的参数都拿出来。这就相当于上图红框一样,把两步操作写在一起了。
在这里插入图片描述
现在来看norm,第17行给出的norm为nn.LayerNorm(dim)。如下图所示,BN是沿着batch维度做归一化。LayerNorm 是对这单样本的所有维度做归一化,也就意味着我们得知道单样本有多少维,所以样本的维度应该作为输入,不然不知道要怎么算维度。所以这也解释了在定义LayerNorm的时候,需要把dim作为输入。
在这里插入图片描述

注:图片引自:知乎:爱吃牛油果的璐璐

第18行的fn,也就是70,71行的Attention,FeedForward。下面我们依次讲解这两个类。

6.2.2 Attention

该小结的代码分析如下图,相信大家看了前面的第1节,再来分析这个应该是没多大问题的。这里经过wo之后,维度变成(b,65,1024),图片为1时,即b=1,维度为(1,65,1024)。
在这里插入图片描述
至此,我们分析完了Attention部分。注意下图给出的heads=16,而上面代码为8。这里请不要纠结,下图给出的为16是自己在调用ViT的时候赋值给的,而上图的8是原代码给的。事实上,这里我们可以任意赋值,不必太纠结。
在这里插入图片描述

6.2.3 FeedForward

该部分正如我们之前分析的那样,本质上就是一个FC。非常简单,这里不做过多介绍。
在这里插入图片描述
注:上图右部分图片引自:博主:太阳花的小绿豆

至此我们分析完了Transformer Encoder的剩下部分。之后的MLP Head然后再分类实际上和CNN的一致,相信如果看到ViT这篇论文,基础的早已掌握,故后面的不在分析。
在这里插入图片描述
另外,值得一提的是,可能有其他博客写的维度和此篇的不一致,这是可理解的。大家不要在乎为什么几篇博客的通道数不一致,重在理解文章的思想即可。
最后附上用ViT做猫狗分类的例程,我自己只跑了一个epoch,效果还行。
在这里插入图片描述

Training Visual Transformer on Dogs vs Cats Data

  • 12
    点赞
  • 47
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

CinzWS

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值