【论文精读】一文看懂Swin Transformer!Shifted Window到底是个啥?Mask之后还和原来一样?

广告位:

图像拼接论文精读专栏 —— 图像拼接领域论文全覆盖(包含数据集),省时省力读论文,带你理解晦涩难懂的论文算法,学习零散的知识和数学原理,并学会写图像拼接领域的论文(介绍、相关工作、算法、实验、结论、并附有参考文献,不用一篇一篇文章再找)

图像拼接论文源码精读专栏 —— 图像拼接有源码的论文全覆盖(有的自己复现),帮助你通过源码进一步理解论文算法,助你做实验,跑出拼接结果,得到评价指标RMSE、SSIM、PSNR等,并寻找潜在创新点和改进提升思路。

超分辨率重建专栏 —— 从SRCNN开始,带你读论文,写代码,复现结果,找创新点,完成论文。手把手教,保姆级攻略。帮助你顺利毕业,熟练掌握超分技术。

有需要的同学可以点上面链接看看。



前言

论文链接:Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

源码地址:https://github.com/microsoft/Swin-Transformer

来看看为什么Swin Transformer能屠榜吧!


Abstract

Transformer做视觉有两个大的挑战:

  1. 目标尺寸多变。不像NLP任务中token大小基本相同,目标检测中的目标尺寸不一,用单层级的模型很难有好的效果。
  2. 图片的高分辨率。尤其是在分割任务中,高分辨率会使得计算复杂度呈现输入图片大小的二次方增长,这显然是不能接受的。

为了解决上述问题,就有了Swin Transformer。顾名思义,Hierarchical(多层级)解决第一个问题;Shifted Windows(滑窗)解决第二个问题。
在这里插入图片描述
如图所示,Swin Transformer通过融合图片块构建多层级的特征图。同时,使计算复杂度与输入图片线性相关,一个window包含若干个patch,仅在window内部计算self-attention。由于window的patch固定,所以计算复杂度与输入图片线性相关。这也就是Shifted Windows,是Swin的缩写,也是本篇文章最精彩的部分。

Shifted Windows Attention

虽然在window内部计算self-attention可能大大降低模型的复杂度,但是不同window无法进行信息交互,从而表现力欠缺。为了更好的增强模型的表现能力,引入Shifted Windows Attention。Shifted Windows是在连续的Swin Transformer blocks之间交替移动的。

Shifted window partitioning in successive blocks

一般的Shifted window partition操作如下图:
在这里插入图片描述

  • 每一个小块叫做一个patch
  • 每一个深色方块框起来的叫一个local window
  • 在每一个local window中计算self-attention
  • 连续两个Blocks之间转换,第一个Block平分feature map,第二个Block从( ⌊ M 2 ⌋ \lfloor\frac{M}{2}\rfloor 2M, ⌊ M 2 ⌋ \lfloor\frac{M}{2}\rfloor 2M)像素有规律地取代前一层的windows。
  • windows数量变化: ⌈ h M ⌉ \lceil\frac{h}{M}\rceil Mh× ⌈ w M ⌉ \lceil\frac{w}{M}\rceil Mw → \rightarrow ( ⌈ h M ⌉ + 1 \lceil\frac{h}{M}\rceil+1 Mh+1)×( ⌈ w M ⌉ + 1 \lceil\frac{w}{M}\rceil+1 Mw+1) 例子中是2×2变成了3×3
  • 但这种方法有一个致命的问题,就是在windows变化的过程中,有些window_size小于 M × M M×M M×M,这就导致了需要用padding方法将其补齐使每个window大小相同,虽然解决了,但增加了计算量。

window partition源码:

def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

阅读源码后发现,源码中也没有实现windows由4个变成9个的操作,而且当window_size为奇数时会报错,也不必过分纠结于此,因为实际的操作是通过下面更有效地方法计算的。

Efficient batch computation

通过给Attention加mask实现,限制自注意力计算量,在子窗口中计算。
在这里插入图片描述

cyclic shitf

详解见文章:【Pytorch小知识】torch.roll()函数的用法及在Swin Transformer中的应用(详细易懂)

源码中的部分:

# cyclic shift
if self.shift_size > 0:
   shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
   attn_mask = mask_matrix
else:
   shifted_x = x
   attn_mask = None

Masked MSA

这应该是本篇论文最精彩的想法,通过mask使shifted window attention和window attention在相同窗口下的计算结果等价,完美解决了上面的window不一致问题,可以对非规则window计算attention。这部分论文中没有阐述,只能结合代码看一下:

源码

# calculate attention mask for SW-MSA
Hp = int(np.ceil(H / self.window_size)) * self.window_size
Wp = int(np.ceil(W / self.window_size)) * self.window_size
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
h_slices = (slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
   for w in w_slices:
       img_mask[:, h, w, :] = cnt
          cnt += 1

mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, 

自己测试代码

还是以4×4输入为例说明。

window_size=2
shift_size=1
#x = torch.randn(1,8,8,3)
#x.shape
H = 4
W = 4
  • h,w,window_size,shift_size分别代表window的高,宽,M和 ⌊ M 2 ⌋ \lfloor\frac{M}{2}\rfloor 2M

去掉self后的代码:

# calculate attention mask for SW-MSA
Hp = int(np.ceil(H / window_size)) * window_size
Wp = int(np.ceil(W / window_size)) * window_size
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
#print("img_mask:",img_mask)
h_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
w_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))

#print("h_slices:",h_slices)
#print("w_slices:",w_slices)

cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        print("img_mask",img_mask)
        cnt += 1

mask_windows = window_partition(img_mask, window_size)  # nW, window_size, window_size, 1
#print("mask_windows:",mask_windows)
mask_windows = mask_windows.view(-1, window_size * window_size)
#print("mask_windows:",mask_windows)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
#print("mask_windows:",attn_mask)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
#print("mask_windows:",attn_mask)
  • 首先是一些划分操作,根据shift_size将window划分成各个区域
  • 二重循环是赋值操作,根据上一步划分的区域,使每个区域的值相同
  • 通过window_partition得到cyclic shift,再通过变换和masked_fill得到最终的attention-mask
mask_windows: tensor([[[   0.,    0.,    0.,    0.],
         [   0.,    0.,    0.,    0.],
         [   0.,    0.,    0.,    0.],
         [   0.,    0.,    0.,    0.]],

        [[   0., -100.,    0., -100.],
         [-100.,    0., -100.,    0.],
         [   0., -100.,    0., -100.],
         [-100.,    0., -100.,    0.]],

        [[   0.,    0., -100., -100.],
         [   0.,    0., -100., -100.],
         [-100., -100.,    0.,    0.],
         [-100., -100.,    0.,    0.]],

        [[   0., -100., -100., -100.],
         [-100.,    0., -100., -100.],
         [-100., -100.,    0., -100.],
         [-100., -100., -100.,    0.]]])

其中,四个mask对应关系为:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
这是feature map在roll操作后的结果,将每个部分拉直进行QKT操作,即可得到对应的mask结果。参考图解Swin Transformer中的Attention Mask部分:
在这里插入图片描述
在这里插入图片描述
得到上边代码的mask结果。

 if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

softmax之后,值为-100的元素会被忽略,从而达到mask的效果,仅得到window中有效的部分的attention。

再reverse回去就达到和原先计算结果一致的目的。类似于CNN中提取特征局部计算的过程。

Swin-T

在这里插入图片描述

Swin Transformer block

在这里插入图片描述

总结

  1. self-attention的计算
  2. local window attention的计算
  3. shifted window attention的计算
  4. 创新点:用window的概念将CNN中局部性计算的思想引入到transformer中

推荐参考文章

  1. 图解Swin Transformer
  2. Swin Transformer各机制详细推导

没有硬件条件,需要云服务的同学可以扫码看看:
请添加图片描述

  • 24
    点赞
  • 77
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 14
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 14
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

十小大

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值