Swin Transformer


讲解非常透彻一篇swin transformer博客

引言

目前Transformer引入到图像领域所面临的的主要挑战是:

  • 视觉实体变化大,在不同场景下视觉Transformer性能未必很好
  • 图像分辨率高,像素点多,Transformer基于全局自注意力的计算导致计算量较大

Swin Transformer 就是为了解决这两个问题所提出的一种通用的视觉架构。Swin Transformer 引入 CNN 中常用的层次化构建方式。

提出了一种包含滑窗操作,具有层级设计 的Swin Transformer。

其中滑窗操作包括不重叠的local window,和重叠的cross-window将注意力计算限制在一个窗口中一方面能引入CNN卷积操作的局部性,另一方面能节省计算量

下面对比一下swin transformer与vision transformer的不同点:

  • Swin Transformer使用了类似卷积神经网络中的层次化构建方法(Hierarchical feature maps),比如特征图尺寸中有对图像下采样4倍的,8倍的以及16倍的,这样的backbone有助于在此基础上构建目标检测,实例分割等任务。而在之前的Vision Transformer中是一开始就直接下采样16倍,后面的特征图也是维持这个下采样率不变。
  • 在Swin Transformer中使用了Windows Multi-Head Self-Attention(W-MSA)的概念,比如在下图的4倍下采样和8倍下采样中,将特征图划分成了多个不相交的区域(Window),并且Multi-Head Self-Attention只在每个窗口(Window)内进行 。相对于Vision Transformer中直接对整个(Global)特征图进行Multi-Head Self-Attention,这样做的目的是能够减少计算量的,尤其是在浅层特征图很大的时候。这样做虽然减少了计算量但也会隔绝不同窗口之间的信息传递,所以在论文中作者又提出了 Shifted Windows Multi-Head Self-Attention(SW-MSA)的概念,通过此方法能够让信息在相邻的窗口中进行传递,后面会细讲。
    在这里插入图片描述

整体架构

在这里插入图片描述
整个模型采取层次化的设计,一共包含4个Stage,每个stage都会缩小输入特征图的分辨率,像CNN一样逐层扩大感受野。

  • 首先将图片输入到Patch Partition模块中进行分块,即每4x4 相邻的像素为一个Patch,然后在channel方向展平(flatten)。假设输入的是RGB三通道图片,那么每个patch就有 4x4=16个像素,然后每个像素有R、G、B三个值所以展平后是16x3=48,所以通过Patch Partition后图像shape由 [H, W, 3]变成了 [H/4, W/4, 48]。然后在通过Linear Embeding层对每个像素的channel数据做线性变换,由48变成C,即图像shape再由 [H/4, W/4, 48] 变成了 [H/4, W/4, C]其实在源码中Patch Partition和Linear Embeding就是直接通过一个卷积层实现的,和之前Vision Transformer中讲的 Embedding层结构一模一样。
  • 然后就是通过四个Stage构建不同大小的特征图,除了Stage1中先通过一个Linear Embeding层外,剩下三个stage都是先通过一个Patch Merging层进行下采样(后面会细讲)。然后都是重复堆叠Swin Transformer Block注意这里的Block其实有两种结构,如图(b)中所示,这两种结构的不同之处仅在于一个使用了W-MSA结构,一个使用了SW-MSA结构。而且这两个结构是成对使用的,先使用一个W-MSA结构再使用一个SW-MSA结构。所以你会发现堆叠Swin Transformer Block的次数都是偶数(因为成对使用)。
  • 最后对于分类网络,后面还会接上一个Layer Norm层、全局池化层以及全连接层得到最终输出。图中没有画,但源码中是这样做的。

接下来分别Patch merging, W-MSA , SW-MSA 以及非常重要的相对位置偏置(relative position bias)进行讲解。

Patch Merging

前面有说,在每个Stage中首先要通过一个Patch Merging层进行下采样(Stage1除外)。如下图所示,假设输入Patch Merging的是一个4x4 大小的单通道特征图(feature map),Patch Merging会将每个2x2的相邻像素划分为一个patch,然后将每个patch中相同位置(同一颜色)像素给拼在一起就得到了4个feature map。接着将这四个feature map在深度方向进行concat拼接,然后在通过一个LayerNorm层。最后通过一个全连接层在feature map的深度方向做线性变化,将feature map的深度由4C变成2C。通过这个简单的例子可以看出,通过Patch Merging层后,feature map的高和宽会减半,深度会翻倍。

与yolo v2中使用的结构类似
在这里插入图片描述

W-MSA结构

引入Windows Multi-head Self-Attention(W-MSA)模块是为了减少计算量 。如下图所示,左侧使用的是普通的Multi-head Self-Attention(MSA)模块,对于feature map中的每个像素(或称作token,patch)在Self-Attention计算过程中需要和所有的像素去计算。但在图右侧,在使用Windows Multi-head Self-Attention(W-MSA)模块时,首先将feature map按照MxM(例子中的M=2)大小划分成一个个Windows,然后单独对每个Windows内部进行Self-Attention
在这里插入图片描述
那么两者的计算量具体差了多少呢?
请添加图片描述

  • h代表feature map的高度
  • w代表feature map的宽度
  • C代表feature map的深度
  • M代表每个窗口(Windows)的大小

它的计算主要使用的是self-attention的公式,
请添加图片描述

MSA的计算量

首先是将输入矩阵 i 转化为q, k, v 三个矩阵, [hw, c] x [c, c] = [hw, c] ,那么计算量就是 3 hw x c^2 ,接着是qk 相乘得到attention, [hw, c] x [c, hw] = [hw, hw], 计算量是 hw^2 x c ,然后是attention与v相乘,[hw, hw] x [hw, c] = [hw, c] ,计算量是 hw^2 x c, 最后经过一层输出层, [hw, c] x [c, c] = [hw, c] ,计算量就是 hw x c^2 ,那么最后的计算量就是 4hwC^2 + 2(hw)^2 C

W-MSA计算量

对于W-MSA模块首先要将feature map划分到一个个窗口(Windows)中,假设每个窗口的宽高都是M,那么总共会得到 (h/M x w/M) 个窗口,然后对每个窗口内使用多头注意力模块,所以最后的计算量应该是:
请添加图片描述

SW-MSA结构

前面有说,采用W-MSA模块时,只会在每个窗口内进行自注意力计算,所以窗口与窗口之间是无法进行信息传递的。为了解决这个问题,作者引入了Shifted Windows Multi-Head Self-Attention(SW-MSA)模块,即进行偏移的W-MSA。如下图所示,左侧使用的是刚刚讲的W-MSA(假设是第L层),那么根据之前介绍的W-MSA和SW-MSA是成对使用的,那么第L+1层使用的就是SW-MSA(右侧图)。根据左右两幅图对比能够发现窗口(Windows)发生了偏移(可以理解成窗口从左上角分别向右侧和下方各偏移了 M/2 个像素)。看下偏移后的窗口(右侧图),比如对于第一行第2列的2x4的窗口,它能够使第L层的第一排的两个窗口信息进行交流。再比如,第二行第二列的4x4的窗口,他能够使第L层的四个窗口信息进行交流,其他的同理。那么这就解决了不同窗口之间无法进行信息交流的问题。

在这里插入图片描述
根据上图,可以发现通过将窗口进行偏移后,由原来的4个窗口变成9个窗口了。后面又要对每个窗口内部进行MSA,这样做感觉又变麻烦了。为了解决这个麻烦,作者又提出而了Efficient batch computation for shifted configuration,一种更加高效的计算方法。下面是原论文给的示意图。
在这里插入图片描述
下图是它的详细示意图
在这里插入图片描述
在这里插入图片描述
如上图可以发现,当移动结束后,4是一个单独的窗口;将5和3合并成一个窗口;7和1合并成一个窗口;8、6、2和0合并成一个窗口。这样又和原来一样是4个4x4的窗口了,所以能够保证计算量是一样的。

这里肯定有人会想,把不同的区域合并在一起(比如5和3)进行MSA,这信息不就乱窜了吗? 是的,为了防止这个问题,在实际计算中使用的是masked MSA 即带蒙板mask的MSA,这样就能够通过设置蒙板来隔绝不同区域的信息了 。关于mask如何使用,可以看下下面这幅图,下图是以上面的区域5和区域3为例。
在这里插入图片描述
如上图所示,对于该4x4大小的window,最后生成的attention一定是一个16 x 16 的矩阵,以第0 元素为例,生成的α0,0到α0,15个attention值,其中有8个元素是区域3的点,但我们只想让它与区域5中的点attention,那么我们可以将像素0与区域3中的所有像素匹配结果都减去100,由于α的值都很小,一般都是零点几的数,将其中一些数减去100后在通过SoftMax得到对应的权重都等于0了 。所以对于像素0而言实际上还是只和区域5内的像素进行了MSA。

Relative Position Bias(相对位置偏置)

那这个相对位置偏执是加在哪的呢,根据论文中提供的公式可知是在Q和K进行匹配并除以根号d后加上相对位置偏置B。
请添加图片描述
如下图,假设输入的feature map高宽都为2,那么首先我们可以构建出每个像素的绝对位置(左下方的矩阵),对于每个像素的绝对位置是使用行号和列号表示的。比如蓝色的像素对应的是第0行第0列所以绝对位置索引是(0,0),接下来再看看相对位置索引。首先看下蓝色的像素,在蓝色像素使用q与所有像素k进行匹配过程中,是以蓝色像素为参考点。然后用蓝色像素的绝对位置索引与其他位置索引进行相减,就得到其他位置相对蓝色像素的相对位置索引。例如黄色像素的绝对位置索引是(0,1),则它相对蓝色像素的相对位置索引为(0,0)-(0,1)=(0,-1) 。接下来将每个相对位置索引矩阵按行展平,并拼接在一起可以得到下面的4x4 矩阵 。
在这里插入图片描述
请注意,我这里描述的一直是相对位置索引,并不是相对位置偏执参数。因为后面我们会根据相对位置索引去取对应的参数。比如说黄色像素是在蓝色像素的右边,所以相对蓝色像素的相对位置索引为(0,-1),绿色像素是在红色像素的右边,所以相对红色像素的相对位置索引为(0,-1)。可以发现这两者的相对位置索引都是(0,-1),所以他们使用的相对位置偏执参数都是一样的。为了能够使用一维向量来表示这个矩阵,作者使用了如下方法。
在这里插入图片描述
接着将所有的行标都乘上2M-1.
在这里插入图片描述
最后将行标和列标进行相加
在这里插入图片描述
那么现在就得到了相对位置索引 , 并不是相对位置偏置参数 , 真正使用到的可训练参数B,是保存在relative position bias table 表里的,这个表的长度是等于(2M-1)X(2M-1) 的,那么上述公式中的相对位置偏执参数B是根据上面的相对位置索引表根据查relative position bias table 表得到的,如下图所示。
在这里插入图片描述

  • 1
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值