逐步解释Swin-Transformer源码
这一行没什么好说的,就是设置参数值,并读取yaml里的模型参数。
是否采用混合精度训练,混合精度训练可以给合适的参数设置合适的位数,而不是全部为float32
分布式训练。
这一段就是什么梯度累加,学习率缩放等一些优化策略。跳到main函数中
数据预处理,logger写日志的模块,暂时不用管。
跳到构建模型中,详细讲解这一部分。
再跳到SwinTransformer中,再看模型的构建代码时,先整体看一下整体一个架构
这个图想必大家都看到过很多次了,首先输入图片(H*W*3),输入到一个Patch Partition中,这个模块SwinTrasfomer中与VIT相同但又略微不同,他将图片按照4*4划分为一个token,送到Linear Embedding中,进行维度变化,利用一个卷积,将通道数变为初始设置的C,源码中为96.
再送到Swintransformer Block中运算,需要注意的是,Block并不会改变输入图像的维度。每一个模块下面的*2,*2,*6,*2就是叠加了多少个Block,patch Merging就是将4*4变为8*8,再将通道数变为2C,以此类推。
Block中主要是一个W-MSA和一个SW-MSA,这就是Swintransformer的创新之处,之后代码中会讲到。
接下来看代码
这里就是一初始化,源码中都有注释,这里用到再详细解释。
跳入第一个模块
初始化,patch_size就是我们想要将图片按照几乘几去划分,都很好理解,patches_resolution就是图像被划分为了几块,proj就是前面说的,通过一个卷积将通道数映射到我们设置的值上。
绝对位置编码,也就是transformer中的位置embedding
接下来到了Basic Block中,根据每一个stage的层数来循环相应次数。
需要注意的是,这里与论文中不同,这里的Basic Block是包括Swintransformer Block和Pach merging的。
跳入Swintransformer Block
跳入WindowAttention中
WindowAttention就是创新之处,由于VIT采取全局各个Token之间算attention,这样会导致,计算量非常大,所以这里采用窗的机制,我们将一幅图片划分为不同的窗,每个窗内自己的attention,即W-MSA
建立相对位置的索引表,具体什么意思。
我们知道,Transformer是利用绝对位置编码的,而这里采用相对位置编码。
就是公式中的B,包含着他的一个位置信息。那B到底如何计算,首先B的维度一定是与QKT的维度相同,因为两者要相加嘛,而QKT的维度如何计算,不用计算,QKT其实就是每两个像素点之间求attention,那么肯定是要遍历所有的像素点,并且包含顺序,也就是说是len*len,如下图
比如,我喜欢你,他的相对位置有如图四种情况,那应该如何将相对位置加到QKT中呢?
拉直,每一行以该行的值当作原点叠加
当然,这个加并不是直接加上去,而是我们有一个索引表,也就是刚刚建立的索引表,0找到0代表的值,-1找到-1的,以此类推。这就是相对位置编码的原理。
那在Swintransformer中是怎样设计的呢?继续看代码
这一部分就是计算相对位置的代码, 每一步对应下面的一张图,注意这里使用了广播机制来计算相对位置,也就是coords_flatten(:,:,None),加M-1是因为防止有负数,乘以2M-1是因为,每一行不同元素之间的attention值不能一样,就比如(1,2)和(2,1)
这就是相对位置编码。同时将索引注册为不用学习的参数。
计算Q,K,V同时再利用卷积完成通道数的变化,最后再给索引表初始化。
这就是一个WindowAttention
这就是一个MLP,也没啥好说的
接下来又到了非常重要的地方,前面我们说到,将图片划分为不同的窗,每个窗内计算attention,这样的话就会导致不同的窗之间没有关联,而实际上应该也是有关联的,所以就出现了SW-MSA
如何实现的呢?就是将窗整体右移window_size/2,也就是移动3,如图
如图,移动之后,我们可以看到左上角多出来两条,将其补刀右下角就如图所展示。
移完后,我们还是要去窗内计算attention,对于0中,完全可以实现,并且完成了对于W-MSA中不同窗之间计算attention,但对于1,2,3,4,5,6,7,8来说,却不可以,因为有我们补上去的,实际的这些像素并不在这里,我们不能计算他们之间的attention。
所以就需要mask,来防止两者之间计算mask
这里就是通过slice对图片进行划分并且标号,如我自己画的图那样
如图
这里之所以加-100,就是为了加到atten矩阵中时,使得不相关的块之间的值很小,这样经过softmax之后就很小了,也就是相关性非常低。
到这里一个Swintransformer模块介绍完了,就是窗不移和窗移连在一起。然后重复[2,2,6,2]次
然后就到
通过一个下采样,使得原本的4*4,变为8*8,相当于卷积中的池化,增大了感受野。
这就是一个stage,然后一共有四个stage循环四次,再经过
数据集采用多少分类,这里就输出多少。再对权值初始化,至此模型搭建完成。
之后,设置模型的优化函数,损失函数,分布式训练等等。
是否继续上次的训练,这个感觉一般都不会用到,暂且不用管。
开始训练,记录开始时间,跳到train_one_epoch中
然后其实就没什么可说的了,提取bach数据,送到模型出,输出预测值,计算损失和梯度,进行梯度累积,达到累计次数,梯度更新并梯度清零。然后就计算准确率啊,损失值啊,时间啊等等。
然后就,看看迭代次数是否达到设置的值,达到保存模型,dist.get_rank()==0是分布式训练中的,这里对分布式训练还不太了解,所以这里也不清楚。
之后就计算top1,top5准确率,计算总体时间,将想要的值通过logger写入日志。
至此,Swin-transformer中的原理及代码介绍完毕。
如有错误,欢迎批评指正!!