逐步理解Swin-Transformer源码

逐步解释Swin-Transformer源码

这一行没什么好说的,就是设置参数值,并读取yaml里的模型参数。

是否采用混合精度训练,混合精度训练可以给合适的参数设置合适的位数,而不是全部为float32

分布式训练。

这一段就是什么梯度累加,学习率缩放等一些优化策略。跳到main函数中

数据预处理,logger写日志的模块,暂时不用管。

跳到构建模型中,详细讲解这一部分。

再跳到SwinTransformer中,再看模型的构建代码时,先整体看一下整体一个架构

这个图想必大家都看到过很多次了,首先输入图片(H*W*3),输入到一个Patch Partition中,这个模块SwinTrasfomer中与VIT相同但又略微不同,他将图片按照4*4划分为一个token,送到Linear Embedding中,进行维度变化,利用一个卷积,将通道数变为初始设置的C,源码中为96.

        再送到Swintransformer Block中运算,需要注意的是,Block并不会改变输入图像的维度。每一个模块下面的*2,*2,*6,*2就是叠加了多少个Block,patch Merging就是将4*4变为8*8,再将通道数变为2C,以此类推。

        Block中主要是一个W-MSA和一个SW-MSA,这就是Swintransformer的创新之处,之后代码中会讲到。

        接下来看代码

这里就是一初始化,源码中都有注释,这里用到再详细解释。

跳入第一个模块

初始化,patch_size就是我们想要将图片按照几乘几去划分,都很好理解,patches_resolution就是图像被划分为了几块,proj就是前面说的,通过一个卷积将通道数映射到我们设置的值上。

绝对位置编码,也就是transformer中的位置embedding

 接下来到了Basic Block中,根据每一个stage的层数来循环相应次数。

需要注意的是,这里与论文中不同,这里的Basic Block是包括Swintransformer Block和Pach merging的。

跳入Swintransformer Block

跳入WindowAttention中

WindowAttention就是创新之处,由于VIT采取全局各个Token之间算attention,这样会导致,计算量非常大,所以这里采用窗的机制,我们将一幅图片划分为不同的窗,每个窗内自己的attention,即W-MSA

建立相对位置的索引表,具体什么意思。

我们知道,Transformer是利用绝对位置编码的,而这里采用相对位置编码。

就是公式中的B,包含着他的一个位置信息。那B到底如何计算,首先B的维度一定是与QKT的维度相同,因为两者要相加嘛,而QKT的维度如何计算,不用计算,QKT其实就是每两个像素点之间求attention,那么肯定是要遍历所有的像素点,并且包含顺序,也就是说是len*len,如下图

比如,我喜欢你,他的相对位置有如图四种情况,那应该如何将相对位置加到QKT中呢?

拉直,每一行以该行的值当作原点叠加

当然,这个加并不是直接加上去,而是我们有一个索引表,也就是刚刚建立的索引表,0找到0代表的值,-1找到-1的,以此类推。这就是相对位置编码的原理。

那在Swintransformer中是怎样设计的呢?继续看代码

这一部分就是计算相对位置的代码, 每一步对应下面的一张图,注意这里使用了广播机制来计算相对位置,也就是coords_flatten(:,:,None),加M-1是因为防止有负数,乘以2M-1是因为,每一行不同元素之间的attention值不能一样,就比如(1,2)和(2,1)

这就是相对位置编码。同时将索引注册为不用学习的参数。

计算Q,K,V同时再利用卷积完成通道数的变化,最后再给索引表初始化。

这就是一个WindowAttention

这就是一个MLP,也没啥好说的

接下来又到了非常重要的地方,前面我们说到,将图片划分为不同的窗,每个窗内计算attention,这样的话就会导致不同的窗之间没有关联,而实际上应该也是有关联的,所以就出现了SW-MSA

如何实现的呢?就是将窗整体右移window_size/2,也就是移动3,如图

如图,移动之后,我们可以看到左上角多出来两条,将其补刀右下角就如图所展示。

移完后,我们还是要去窗内计算attention,对于0中,完全可以实现,并且完成了对于W-MSA中不同窗之间计算attention,但对于1,2,3,4,5,6,7,8来说,却不可以,因为有我们补上去的,实际的这些像素并不在这里,我们不能计算他们之间的attention。

所以就需要mask,来防止两者之间计算mask

这里就是通过slice对图片进行划分并且标号,如我自己画的图那样

如图

 这里之所以加-100,就是为了加到atten矩阵中时,使得不相关的块之间的值很小,这样经过softmax之后就很小了,也就是相关性非常低。

到这里一个Swintransformer模块介绍完了,就是窗不移和窗移连在一起。然后重复[2,2,6,2]次

然后就到

通过一个下采样,使得原本的4*4,变为8*8,相当于卷积中的池化,增大了感受野。

这就是一个stage,然后一共有四个stage循环四次,再经过

数据集采用多少分类,这里就输出多少。再对权值初始化,至此模型搭建完成。

之后,设置模型的优化函数,损失函数,分布式训练等等。

是否继续上次的训练,这个感觉一般都不会用到,暂且不用管。

开始训练,记录开始时间,跳到train_one_epoch中

然后其实就没什么可说的了,提取bach数据,送到模型出,输出预测值,计算损失和梯度,进行梯度累积,达到累计次数,梯度更新并梯度清零。然后就计算准确率啊,损失值啊,时间啊等等。

然后就,看看迭代次数是否达到设置的值,达到保存模型,dist.get_rank()==0是分布式训练中的,这里对分布式训练还不太了解,所以这里也不清楚。

之后就计算top1,top5准确率,计算总体时间,将想要的值通过logger写入日志。

至此,Swin-transformer中的原理及代码介绍完毕。

如有错误,欢迎批评指正!!

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值