论文阅读:Swin Transformer

Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

0 摘要

  1. 本文提出了一种分层Transformer,其表示是用移位窗口计算的,通过将自注意力计算限制在不重叠的本地窗口同时,还允许跨窗口连接来提高效率。这种分层架构在各种尺度上建模很灵活,并且具有相对于图像大小的线性计算复杂度。
  2. 能和视觉任务兼容,包括图像分类和密集预测任务,目标检测和语义分割。
  3. 性能相当好,证明了基于 Transformer 的模型作为视觉主干的潜力。 可以作为Transformer的backbone。
  4. 代码:https://github.com/microsoft/Swin-Transformer

1 引言

CV方向通用方法是CNN,NLP领域存在Transformer方法。本文寻求扩展 Transformer的适用性,使其可以作为CV的backbone。

CV和NLP的差异:

  1. NLP中Transformers的基本元素是单词token,视觉元素在规模上会有很大差异,现有的Transformer方法中token是固定规模,不适用于CV。
  2. 图像中像素的分辨率很高。 对于高分辨率图像上的Transformer来说难以处理,因为其自注意力的计算复杂性与图像大小成二次方。

文本提出了一个通用的Transformer主干,称为 Swin Transformer,它构建分层特征图并且对图像大小具有线性计算复杂度

image-20210804202832878

  1. 分层特征图:通过从小尺寸的patch开始并逐渐与相邻patch合并到更深层中构建分层表示
  2. 线性计算复杂度:通过在图像分割的非重叠窗口内局部计算自注意力来实现。每个窗口中的patch数量是固定的,因此复杂度与图像大小成线性关系。

image-20210804203515633

Swin Transformer 的一个关键设计元素是它在连续自注意力层之间的窗口分区的移动。移动的窗口桥接前一层的窗口,提供它们之间的连接。

2 相关研究

3 方法

3.1 整体架构

image-20210804204311199

  1. 先将图片分成不重叠的块,每个块成为token,其特征是原始像素RGB值的串联,然后经过一个线性嵌入层投影到维度C。

  2. 在这些token上应用若干修正自注意力计算的Swin Transformer 块。Transformer 块保持令牌的数量,与线性嵌入一起被称为“阶段 1”。

  3. 随着网络变深,通过块合并层来减少token的数量。

    1. 第一个块合并层连接每组 2 × 2 2×2 2×2相邻块的特征,并在 4 C 4C 4C 维连接特征上应用线性层( 4 C 4C 4C 2 C 2C 2C)。 token数减少4倍,并且输出维度设置为 2 C 2C 2C。 之后应用 Swin Transformer 块进行特征转换,保持 H 8 × W 8 \frac{H}{8}× \frac{W}{8} 8H×8W。 补丁合并和特征转换的第一个块表示为“第 2 阶段”。
    2. 该过程重复两次,分别为“第 3 阶段”和“第 4 阶段”,输出分辨率分别为 H 16 × W 16 \frac{H}{16}× \frac{W}{16} 16H×16W H 32 × W 32 \frac{H}{32}× \frac{W}{32} 32H×32W
  4. 这些阶段共同产生一个分层表示,具有与典型卷积网络相同的特征图分辨率。因此,所提出的架构可以方便地替换将骨干网络置于现有方法中,用于各种视觉任务。

Swin Transformer块

Swin Transformer 是通过将 Transformer 模块中的标准多头自注意力 (MSA,multi-head self attention) 模块替换为基于移动窗口的模块而构建的,其他层保持不变。由一个基于移动窗口的MSA 模块组成,后跟一个 2层MLP,其间具有GELU非线性。 在每个MSA模块和每个 MLP 之前应用一个正则化层,在每个模块之后应用一个残差连接

3.2 基于移动窗口的自注意力

非重叠窗口中的自注意力

为了有效建模,提出在局部窗口内计算自注意力。窗口以不重叠的方式均匀地划分图像。 假设每个窗口包含 M × M M×M M×M个块,全局MSA 模块和基于 h × w 块图像的窗口的计算复杂度为:
Ω ( M S A ) = 4 h w C 2 + 2 ( h w ) 2 C Ω ( W − M S A ) = 4 h w C 2 + 2 M 2 h w C \Omega(MSA)=4hwC^2+2(hw)^2C\\ \Omega(W-MSA)=4hwC^2+2M^2hwC\\ Ω(MSA)=4hwC2+2(hw)2CΩ(WMSA)=4hwC2+2M2hwC
(这里的计算复杂度就是qkv那一套理论)可以看到,全局自注意力复杂度正比于块数的平方,而非重叠窗口的自注意力对块数是线性的。

连续块中的移位窗口分区

基于窗口的自注意力模块缺乏跨窗口的连接,限制了能力。 为了在保持非重叠窗口的高效计算的同时引入跨窗口连接,作者提出了一种移位窗口分区方法,在连续的 Swin Transformer 块中的两个分区配置之间交替

image-20210804213543560

第一个模块使用从左上角像素开始的常规窗口分区策略,将 8 × 8 8 × 8 8×8 特征图均匀地划分为大小为 4 × 4 ( M = 4 ) 4 × 4 (M = 4) 4×4(M=4) 的$ 2 × 2$ 窗口。 然后,下一个模块采用从前一层的窗口配置偏移的窗口配置,通过将窗口从规则分区的窗口中移动 ( ⌊ M 2 ⌋ , ⌊ M 2 ⌋ ) (\lfloor\frac{M}{2}\rfloor,\lfloor\frac{M}{2}\rfloor) (2M,2M)个像素。使用移位窗口分区方法,连续的 Swin Transformer 块计算为:

image-20210804213318224

符号说明见上图。

移位配置的高效批量计算

移位窗口分区的一个问题是它会产生更多的窗口,从$\lceil \frac{h}{W} \rceil \times \lceil \frac{h}{W} \rceil $ 到 $(\lceil \frac{h}{W} \rceil+1) \times (\lceil \frac{h}{W} \rceil+1) $ ,有的窗口会比 M × M M×M M×M小。

作者提出了一种更有效的批量计算方法,通过向左上方向循环移位。 在这种移位之后,在特征图中一个批量窗口可能由几个不相邻的子窗口组成,因此采用屏蔽机制将自注意力计算限制在每个子窗口内。使用循环移位,批处理窗口的数量与常规窗口分区的数量相同,因此也是有效的。

image-20210804220937688

相对位置偏置

image-20210804221057129

B ∈ R M 2 × M 2 B\in R^{M^2\times M^2} BRM2×M2 M M M是一个窗口内的块的个数。由于沿每个轴的相对位置在 $[−M + 1, M −1] $范围内,参数化一个更小的偏置矩阵 B ^ ∈ R ( 2 M − 1 ) × ( 2 M − 1 ) \hat B \in R ^{(2M−1)×(2M−1)} B^R(2M1)×(2M1),并且 B B B的值取自 B ^ \hat B B^

3.3 架构的变体

image-20210804221402783

4 实验

4.1 图像分类

image-20210804222012494

4.2 目标检测

image-20210804222100096

4.3 语义分割

image-20210804222116558

4.4 消融实验

移位窗口

image-20210804222306416

相对位置偏置
自注意方法

image-20210804222322506

5 总结

本文介绍了 Swin Transformer,一种新的视觉 Transformer,它产生分层特征表示并具有线性计算复杂度。Swin Transformer实现了先进的性能。 希望 Swin Transformer 在各种视觉问题上的强大表现将鼓励视觉和语言信号统一建模。作为 Swin Transformer 的一个关键元素,基于平移窗口的自注意力在视觉问题上被证明是有效和高效的,作者也期待研究其在自然语言处理中的应用。

6 源代码

6.1 Part Partition

注意,源代码里面使用的卷积层,把Patch Partition和Linear Embeeding作用合到了一起

6.2 Swin Transformer block

window 划分:

def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

移位窗口:

# cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

mask:

img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
  
mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)

代码太长了不想看了,跑起来了就行

  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值