Swin Transformer详解
1. 前言
这篇文章主要讲述的是swin transofmrer这个基本的网络架构。虽然现在是2023年,离swin transformer发布已经相隔三年之久了,但是这篇文章点此直达依然在很多下游任务中表现出SOTA的水平。
官方代码:点击直达
2. 网络详解
首先我们来放出一张Swin Transformer的模型架构图,具体的细节我们先暂时不详谈。
在这里我们可以看到,在Swin Transformer当中,具有如下特点:
Patch Partition
结构- Stage结构,除了第一个Stage之外,其余的结构包含一个
Patch Merging
结构以及Swin Transformer Block
结构- 每个Stage的Swin Transformer Block下面,表示的是这个Block会重复多少次。比如Stage 1下面就重复2次,Stage 3下面就重复6次…
- 在Swin Block当中,有一个
W-MSA
和一个SW-MSA
结构,是Swin Transformer作者为了减少相对于Vision Transformer复杂地计算而设计的。
好了,分析完网络架构之后,我们就来逐个击破!
2.1 Patch Partition结构
所谓的Patch Partition结构,它的作用就是将一张图像,拿着刀切成一块块的小图像。比如说我输入了一个128x128的图像,然后我会把这张图像切分成32x32个大小的图像块,每个图像块占4x4个像素,如图展示了切分之后4x4的图像块。
然后Patch Partition的操作就是,将这些像素值按照Channel方向进行展平,再套一个Layer Norm的操作。这里值得注意的是,因为我们输入的图像一般都是RGB三通道的,那么经过Patch Partition操作之后,我们得到的图像的维度值就变成了
H
4
×
W
4
×
48
\frac {H}{4} \times \frac {W}{4} \times 48
4H×4W×48
因为48=4(H)x4(W)x3(channel)
。
2.2 Patch Merging 结构
所谓的Patch Merging 操作,他的作用就是,把相同颜色的块,拿出来拼接到一起。然后分别进行Concat、LayerNorm、以及Linear,把它变成2个2X2的操作。
我们还是以4X4的feature map为例,Patch Merging主要做了一下事情:
- 拿出相同颜色的块,也就是把4x4变成4个2x2.
- 把4个2x2的块,按照channel维度进行拼接
- LayerNorm
- 将4个2x2组合成的feature map通过线性变换变成2个2x2组成的features map.
2.3 W-MSA和SW-MSA
接下来是重点,我们来放一张图。
普通的self-attention是,每个像素都要和其他像素做点乘运算,消耗极大。
而我们的W-MSA,只是在窗口内部进行MSA计算,大大降低了复杂度。
但是这样的话,会出现一个问题,那就是窗口和窗口之间是没有交互的,这不利于信息之间的融合和计算,为了解决这个问题,作者提出了SW-MSA(Shifted Window Self-Attention)。
也就是如下图所示,黄色线沿着红色虚线的箭头向下平移。
然后为了实现并行计算,作者通过将其巧妙地组合,进行划分计算。
然后呢,我们知道,下图中绿色的区域,B区域和它左边的区域在原图中是不相邻的,所以我们不想计算他们的注意力(计算没用呀!)
那么我们采用将不相关的权重-100的方式进行处理,如下图所示:
a
0
,
0
表示的是第
0
个像素和第
0
个像素的权重,其余以此类推
a_{0,0}表示的是第0个像素和第0个像素的权重,其余以此类推
a0,0表示的是第0个像素和第0个像素的权重,其余以此类推
计算完成之后,我们再挪回去,即可结束计算了。
在这里还需要提一点的就是,SW-MSA怎么决定需要挪第几行和第几列呢?如果你的Window是3X3
的,那么
l
i
n
e
s
=
⌊
3
2
⌋
lines=\lfloor \frac {3}{2} \rfloor
lines=⌊23⌋,即移动第一行和第一列。
4. relative position bias
最后一个内容就是相对位置偏置。
首先我们来回顾一下绝对位置索引,以下图feature map为例,第一个编号是(0,0),第二个是(0,1)…
那么相对位置索引就是:
以蓝色块为参照点,蓝色点为(0,0),其他的点的偏置索引就是(0,0)减去(m,n),m.n表示的是其他像素的位置,其他类似。然后在行方向上进行展平,之后组成一个4x4大小的位置索引。
之后,行、列的横纵坐标分别加上M-1,行标乘以2M-1,然后行列相加,就可以把二维位置索引转化成一维位置索引。
举个例子:
feature map的大小是2x2,那么M=2,以(0,-1)为例,
(0,-1) + (2x1-1) = (1,0) //行加上M-1
(1,0) x (2x2-1) = (3,0) // 行乘以2M-1
3 + 0 = 3 //行列相加
然后我们根据relative bias table
往索引表中填值。(这个relative positon bias table是可训练的,所以里面的值是变化的)
引用
[1] 霹雳啪啦大佬的B站视频
[2] Swin Transformer论文