swin-transformer原理介绍

作者:Vincent Liu
链接:https://www.zhihu.com/question/521494294/answer/2492957365
来源:知乎
 

Swin Transformer 的提出可以说是里程碑式的,在学术界引起了巨大的反响,网络上也有许多讲解的教程,这一篇图解Swin Transformer就写得非常棒,另外强烈推荐Zhu老师在b站的精读视频:Swin Transformer论文精读【论文精读】_哔哩哔哩_bilibili

这篇文章就来记录一下我对 Swin Transformer 的理解。

首先,在 PVT (Pyramid Vision Transformer) 中我已经分析过 ViT 存在的一些缺陷,本文就不再讨论,简单来说就是网络设计的问题和自注意力机制的显存占用问题。下面直接讨论 Swin Transformer 的网络结构。

Swin Transformer 的网络结构

金字塔结构

与 PVT 一样,Swin Transformer 在网络的设计上也实现了金字塔的结构,如下图所示(实际上下面这张图就是从 PVT 论文中摘取的)。

网络结构对比

上图从左到右,我们可以看到

  • CNN:金字塔结构,网络越深,feature map 尺寸越小,channel 数越多。
  • ViT:柱状结构,单一的 feature map 尺寸。
  • Swin Transformer:金字塔结构,网络越深,feature map 尺寸越小,channel 数越多。

这样设计的好处就是能够像 CNN 网络一样处理不同尺度的特征。

Swin Transformer 的总体结构

Swin Transformer 总体结构

从上图我们可以观察到在输入端有一个 Patch Partition 的操作,也就是 Vision Transformer 常规的切图。然后是经过一个线性映射进入第一个 Swin Transformer Block,从而完成 Stage 1 这个模块。Stage 2 的输入特征图大小为 H4×W4×2C\frac{H}{4} \times \frac{W}{4} \times 2C ,之后每个 stage 都会把特征图大小减小一半,通道数增加一倍,也就实现了所谓的金字塔结构。

注意上图 Stage 1 里面的 Swin Transformer Block 下面写了个 x2,这个表示的是这个模块中有两个部分,分别是 Window-based Multi-head Self-attention (W-MSA) 和 Shifted Window-based Multi-head Self-attention (SW-MSA),这个 SW-MSA 是紧跟着 W-MSA 的,属于串联的关系。

在上图 Stage 3 里面我们可以看到 x6,这代表了 Stage 3 有3组 W-MSA + SW-MSA 串联在一起。

W-MSA & SW-MSA

W-MSA

从 W-MSA 说起,它的设计主要是为了解决 Vision Transformer 的自注意力机制显存占用高的问题。顾名思义,Window-based Multi-head Self-attention 就是把自注意力机制限制在了一个窗口中。

如下图所示,假设输入特征图的大小为 H×W=56×56H \times W = 56 \times 56 ,num_patches 为 8×88 \times 8 ,每个 patch 的大小为 7×77 \times 7 ,在这个设定下,MSA 计算时 Q×KQ \times K 会涉及到 8×8=648 \times 8 = 64 个 patches 的相乘。相对而言,W-MSA 将自注意力的计算限制在了每一个窗口中,也就是下图的绿色窗口中,每个窗口中的 patch 只会和同一个窗口内部的其他 patch 计算注意力,Q×KQ \times K 的计算只会涉及到窗口中 4×4=164 \times 4 = 16 个 patches 的相乘。这样就限制了 Q×KQ \times K 运算时的显存占用。

W-MSA

这里分别对 MSA 和 W-MSA 的自注意力机制的计算复杂度做一个估计。

如下图所示,对于 MSA:

  • 将大小为 HWCHWC 的输入特征图通过全连接层转换成大小为 HWCHWC 的 、、Q、K、VQ、K、V :3HWC23HWC^2
  • Q×KQ \times K 计算得到 Attention Matrix: (HW)2C(HW)^2C
  • Attention Matrix 乘以 V: (HW)2C(HW)^2C
  • 最后再接一个全连接层: HWC2HWC^2

全部加起来就是 4HWC2+2(HW)2C4HWC^2+2(HW)^2C ,相对于输入特征图大小而言计算复杂度就是 O(N2)O(N^2) 。

对于 W-MSA:

  • 前面 MSA 的推导中所有的 HWHW 都被限制在了一个尺寸为 M×MM \times M 的窗口中
  • 对于一个窗口中的计算复杂度,我们可以用 M2M^2 替代 4HWC2+2(HW)2C4HWC^2+2(HW)^2C 式子里的 HWHW ,得到 4M2C2+2(M)4C4M^2C^2+2(M)^4C

一共有HWM2\frac{HW}{M^2} 个窗口,所以总的计算复杂度是 HWM2×(4M2C2+2(M)4C)=4HWC2+2M2HWC\frac{HW}{M^2} \times (4M^2C^2+2(M)^4C) = 4HWC^2 + 2M^2HWC 。因为 MM 为窗口的尺寸,是一个较小的固定值,取决于我们对网络的设计,所以可以看成一个常数,在这种情况下相对于输入特征图大小而言总的计算复杂度就是 O(N)O(N) 。

W-MSA 计算复杂度

W-MSA 把 MSA 的计算复杂度从 O(n2)O(n^2) 降低到了 O(n)O(n) ,使得 Vision Transformer 的显存占用不再是个瓶颈。然而 W-MSA 也不是完美的,将注意力限制在窗口中也必然导致全局注意力丧失的问题,为了解决这个问题,作者在 W-MSA 后面加了一个 SW-MSA。

SW-MSA

从下图中可以很直观地看到,SW-MSA 的提出主要是为了补充位于不同窗口中的 patch 之间的注意力。

W-MSA(左)SW-MSA(右)

SW-MSA 最有意思的地方在于其代码的实现上。先说一下 W-MSA 在代码上的实现,其实是很直观的,首先是把特征图平均切分一下,然后在每个窗口内部做自注意力的计算。然而到了 SW-MSA 时,就不是平均切分特征图了,从上图可以看到需要将其切分成好几个尺寸不一的窗口,再在几个大小不一的窗口中计算自注意力,感觉实现起来会很麻烦的样子。Swin Transformer 的作者就想到能不能用同一套代码实现 W-MSA 与 SW-MSA 呢?答案是肯定的,只是要对 SW-MSA 做一些额外的处理。

如下图的中间部分所示,作者将不同的窗口做了平移,例如把 C 部分下移,把 B 部分右移,这样就把下图的左边那个样子转换成了下图的右边那个样子,也就是变成和 W-MSA 的切图一样的4个窗口了,也就可以直接复用 W-MSA 的自注意力计算的代码了。

当然,还需要解决另外一个问题,举个例子,下图的右边部分,3和6被分到了同一个窗口中,那么被标记成3的区域的 patch 在计算注意力的时候就会和被标记成6的区域的 patch 有交集。然而,我们只想要3和3的注意力、以及6和6的注意力,所以需要借助 mask 来遮挡我们不需要的注意力。

Cyclic Shift

还是拿区域3和区域6这个窗口举例,在下图中,假设输入特征图 H=W=14 patchesH=W=14 \text{ patches} ,我们先把区域3和区域6里面的每一个 patch 拿出来,3里面有 4×7=284 \times 7 = 28 个 patches,6里面有 3×7=213 \times 7 = 21 个 patches,一共是 28+21=4928+21=49 个 patches。

计算自注意力的时候,这49个 patches 要和自己的 transpose 相乘,得到一个尺寸为 (28+21)×(28+21)(28+21) \times (28+21) 的 attention matrix。如下图所示,attention matrix 里左上角的部分是 28×2828 \times 28 个区域3和区域3里的 patch 之间计算的注意力,右上角的部分是 28×2128 \times 21 个区域3和区域6里的 patch 之间计算的注意力,左下角的部分是 21×2821 \times 28 个区域6和区域3里的 patch 之间计算的注意力,右下角的部分是 21×2121 \times 21 个区域6和区域6里的 patch 之间计算的注意力。

而我们实际上只想保留 attention matrix 里左上角区域3和区域3里的 patch 之间计算的注意力,以及右下角区域6和区域6里的 patch 之间计算的注意力。所以要给 attention matrix 加上一个 mask matrix,在不要的部分值为-100,在保留的部分值为0,这样后面通过一个 softmax 的时候,值为-100的部分就为趋于0。

对于其他窗口的计算也是类似的,这里就不再赘述,感兴趣可以参考一下这篇文章:图解Swin Transformer

Masking in SW-MSA

W-MSA + SM-MSA 对比 MobileNet

本质上通过 W-MSA + SW-MSA 替代 MSA 是一个减少计算量的工作,其思想有点像 MobileNet 用深度卷积+逐点卷积代替普通的卷积。

如下图所示,MobileNet 先通过深度卷积核处理特征图,每个深度卷积核只作用于单个通道,再用逐点卷积补充通道之间的信息交流。

深度卷积+逐点卷积

相对位置编码

上文中对 W-MSA + SM-MSA 的讨论其实还缺了一部分,我们看自注意力的计算公式 Attention(Q,K,V)=Softmax(QKTd+B)V\text{Attention}(Q, K, V) = \text{Softmax}(\frac{QK^T}{\sqrt{d}} + B)V ,前文的讨论只涉及了 Q、K、V,这里面还有个 B,就是相对位置编码。

如下图所示,假设 W-MSA 的窗口尺寸为 2×22 \times 2 ,我们自注意力的计算就是在这个窗口内完成的。

W-MSA 的窗口

Swin Transformer 的作者在设计相对位置编码的时候也考虑到了节约计算资源,首先,通过如下代码生成一个可学习的 relative positional embedding table。

relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))

也就是对于每一个 attention head,都生成一个 (2⋅window size−1)×(2⋅window size−1)(2 \cdot \text{window size} - 1) \times (2 \cdot \text{window size} - 1) 大小可学习的相对位置编码矩阵,在我们的例子中, window size=2\text{window size}=2 ,所以这个相对位置编码矩阵的大小为 3×33 \times 3 。这个相对位置编码矩阵包含了我们尺寸为 2×22 \times 2 的窗口中所有 patch 之间可能存在的相对位置,例如上下一格的关系,左右一格的关系,等等。

相对位置编码矩阵

接着,我们还需要一个 index 来检索相对位置编码矩阵中的值。举个例子,假设我们有一个相对位置编码 index table,如下图中左边那个小图所示,里面的元素代表了不同 patch 之间相对位置对应的 index 值。

例如, q0q_0 和 k2k_2 之间的关系,查看下图右边那个小图,就是上下的关系(0号 patch 在2号 patch 上边),其在下图左边那个小图中的 index 是1。同样的,q1q_1 和 k3k_3 之间的关系,查看下图右边那个小图,也是上下的关系(1号 patch 在3号 patch 上边),其在下图左边那个小图中的 index 也是1。这就是说, q0q_0 和 k2k_2 之间的相对位置编码 index 与 q1q_1 和 k3k_3 之间的相对位置编码 index 相同,都是1,我们就回到上图中的大小为 3×33 \times 3 的相对位置编码矩阵中寻找1号元素的值作为他们的相对位置编码。

相对位置编码矩阵检索

这个大小为 3×33 \times 3 的相对位置编码矩阵好理解,就是一个可学习的矩阵,那么这个相对位置编码 index table 如何获取呢?

相对位置编码 index table

首先,还是假设 W-MSA 的窗口尺寸为 2×22 \times 2 ,我们可以分别标出这个窗口中每个 patch 的x轴坐标和y轴坐标,如下图所示。

W-MSA 的窗口中patch的坐标

计算这个窗口中每个 patch 间所有相对位置的一个好方法就是用所有 patch 的x、y坐标做差值。利用广播机制,我们分别用x轴坐标向量减去x轴坐标向量的转置以及y轴坐标向量减去y轴坐标向量的转置。

坐标向量相减

得到下图的相对位置编码 index 矩阵。

坐标向量相减结果

然而,我们想要的相对位置编码 index 矩阵是可以从里面获取相对位置编码的 index 值,显然上图中的矩阵里的元素并不是合理的 index 数值。所以还要进一步处理。

首先,index 最好是从0开始,不应该有负值,于是我们做如下操作:

去负值

其次,x和y方向的相对位置编码应该是不一样的,于是我们对x坐标单独做如下操作:

x坐标处理

最后再相加一下,就得到了最后的相对位置编码 index 矩阵:

相加

到目前为止,Swin Transformer Block 里的主要内容就讨论完了,再看一下 Swin Transformer 的结构图,我们会发现 Swin Transformer Block 出来接了一个 Patch Merging 模块,上文中讨论的 W-MSA 和 SW-MSA 都是为了解决 Vision Transformer 显存占用的问题,而这个 Patch Merging 解决的是 Vision Transformer 另外一个主要问题,也就是通过改变特征图尺寸使得 Vision Transformer 可以处理不同尺度的特征。

Patch Merging

Patch Merging

Patch Merging 和 Pooling 非常相似。我们先对大小为 HWCHWC 特征图进行窗口大小为 2×22 \times 2 的提取,如下图所示。

Patch Merging 切分

接着就是把不同编号的 patch 分别提取出来,得到4个尺寸为 H2×W2×C\frac{H}{2} \times \frac{W}{2} \times C 的特征图,然后再将这个4个特征图在 channel 方向拼接起来,得到尺寸为 H2×W2×4C\frac{H}{2} \times \frac{W}{2} \times 4C 的特征图。

Patch Merging 切分拼接

注意下图 Stage 2,特征图经过 Patch Merging + Swin Transformer Block 后尺寸除以2,通道数乘以2,而上图中尺寸除以2,通道数却数乘以4。

Patch Merging

所以我们需要将其拉直,经过一个 Linear 层降维,再 Reshape 成二维,最后尺寸从 H2×W2×4C\frac{H}{2} \times \frac{W}{2} \times 4C 变成了 H2×W2×2C\frac{H}{2} \times \frac{W}{2} \times 2C 。

Patch Merging Linear + Reshape

对比 PVT

对 Swin Transformer 结构设计做一个小结:

  • 通过 Patch Merging 实现了类似于 Pooling 的操作逐层减小特征图尺寸
  • 在 Swin Transformer Block 中通过 W-MSA + SW-MSA 解决了显存占用问题

我们再对比一下 PVT,发现 PVT 也是用不同的方法解决了同样的问题:

  • 在每一个 Block 开头进行切图 + Linear,逐层减小特征图尺寸
  • 通过 Spatial Reduction Attention (SRA) 减少 Attention 计算的显存占用

PVT 结构

在减少显存占用这方面,Swin Transformer 其实更胜一筹,PVT 提出的 SRA 将 K、V 的尺寸从 HWCHWC 降低到 HWR2C\frac{HW}{R^2}C ,从而将计算复杂度从 O(n2)O(n^2) 降低到 O(n2R2)O(\frac{n^2}{R^2}) 。(本来在估计计算复杂度的时候 R2R^2 是应该忽略的,不过在这里缩小系数 R2R^2 相对于输入特征图尺寸 HWHW 是一个较大的值)

然而,显然在特征图尺寸 HWHW 较大时,PVR 的计算复杂度 O(n2R2)O(\frac{n^2}{R^2}) 还是会比 Swin Transformer 的 O(n)O(n) 高很多的。

PVT 的 SRA

实验对比

ImageNet-1K

如下图所示,在分类任务上,Swin Transformer 无论是 FLOPs 还是准确率都是优于其他两个 Vision Transformer 的算法的(ViT 与 DeiT)。

然而,对比起卷积神经网络,似乎没有明显优势,感觉比 EfficientNet 还要弱一些。估计 Swin Transformer 的优势不在分类任务。作者在论文里说的是“Compared with the state-of-the-art ConvNets, i.e. RegNet and EffificientNet, the Swin Transformer achieves a slightly better speed-accuracy trade-off”。

ImageNet-1K

ADE20K

在 ADE20K 分割任务是,Swin Transformer 几乎是横扫了。这也得益于其金字塔结构的设计,使其能够胜任分割这种像素级的任务。

ADE20K

另外,论文里还有关于相对位置编码和 SW-MSA 的消融实验,感兴趣可以去读一下论文,这里就不在赘述。

CNN 和 Swin Transformer 的异同

Swin Transformer 这篇论文的提出是里程碑式的,使得 Vision Transformer 在处理计算机视觉任务时不再有明显短板,也引发了许多关于CNN 和 Swin Transformer 的异同的思考。

相似点

  • 操作都在 window 里(卷积核通常为 3×33 \times 3 到 7×77 \times 7 大小的窗口 vs W-MSA)
  • 都是金字塔结构
  • 都是 pretraining + finetuning 的 pipeline
  • transfomer 的切图甚至都可以用 Conv2d 写(kernel_size=stride=patch_size)
nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

不同点

CNN 有很强的 inductive bias(归纳偏置),也就是针对不同的特定的问题,我们认为模型应该有某些特点,从而做出一系列针对模型设计的人为限制:

  • 处理图片时,每个位置的信息与周围的信息相关,因此设计出 Conv
  • 处理 NLP 任务时,输出的结果与单词的顺序有关,因此设计出 RNN
  • CNN 空间平移不变性的特点
  • CNN 网络底层处理简单的特征,逐层进行特征聚合处理的特征越来越复杂

Transformer 相对于 CNN 的 inductive bias 就少很多:

  • Patch 内信息建模 --> MLP
  • Patch 间信息建模 --> MSA

由于 Transformer 的 inductive bias 少,所以在处理图片任务时就需要更多的数据去训练,但是可能上限会更高。

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值