swin transformer 模块理解


前言

【个人学习笔记记录,如有错误,请指正】


配置文件使用 swin_small_patch4_windows7_224.yaml 文件,batch_size = 4

一、Patch Embedding

【Patch embedding】其实就是将输入的 224 * 224 大小的图像,经过【卷积】和【LayerNorm】操作,将图像缩放为 56 56 大小的特征图。然后将特征图reshape 为 (4, 3136, 96)形状,这里的 4 为【batch_size】,3136 = 5656,96 为特征图的通道数。

在这里插入图片描述

二、swin transformer block

【swin transformer block】其实就是下面的流程图。
在这里插入图片描述
这里主要对 【W-MSA】和 【SW-MSA】进行理解。

1.torch.roll 操作

在进行【roll】操作之前,需要将特征图的形状变为(B, H, W, C),即【4, 56, 56, 96】。

【注意:】,这里的【roll】操作是针对【SW-MSA】才有的。
代码如下:

if self.shift_size > 0:
    shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
    shifted_x = x

示意图如下:
在这里插入图片描述
源码中,将特征图移动(-3, -3),如上图就是最后的特征图最后的形状。然后将这个新的特征图,进行窗口的划分,然后进行注意力操作,

2. window_partition 操作

代码如下:

x_windows = window_partition(shifted_x, self.window_size)

def window_partition(x, window_size):
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

示意图如下:
在这里插入图片描述
左边是将特征图划分为7*7 大小的窗口,右边是整个【window_partition】操作的x 的形状变化。

3. W-MSA

将上面的 7 * 7 的特征图,做注意力操作,其中输入的特征图形状为 7 * 7 = 49 和通道数 96,32 是因为多头注意力机制,这里是 3 头注意力机制。

注意力流程示意图如下:
在这里插入图片描述

代码里的 WindowAttention 的流程大致就是这样子,(这里的位置编码,代码不是很懂,有明白的可以解释解释),其中 mask 机制是在 SW-MSA 中使用到的。

4. SW-MSA

swin transformer 中 为了解决每个窗口之间的交互,引入了对特征图的偏移(偏移量为 3),但是引入偏移之后,源特征图中的窗口数量就变多了,这样就使得计算量变大。这样就引入了 mask 方法。
在这里插入图片描述

下面是对 mask 掩码生成的过程,这里以特征图大小为 6 * 6 ,窗口大小为 3 * 3,shift 值为 2为例。

原代码做如下修改:

import torch
import matplotlib.pyplot as plt


def window_partition(x, window_size):
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

window_size = 3
shift_size = 2
H, W = 6, 6
img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
h_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
w_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

mask_windows = window_partition(img_mask, window_size)  # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, window_size * window_size)

attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

plt.matshow(img_mask[0, :, :, 0].numpy())
plt.text(0, 0, '0')
plt.matshow(attn_mask[0].numpy())
plt.matshow(attn_mask[1].numpy())
plt.matshow(attn_mask[2].numpy())
plt.matshow(attn_mask[3].numpy())

plt.show()

对于任意一张大小的特征图,都需要做相同大小的掩码模板。
(这里以 6 * 6 大小的特征图,窗口大小为 3 * 3,shift 为 2 为例)
生成 mask 模板示意图如下:
在这里插入图片描述
左右两个图示对应的。颜色不同表示不同的掩码区域。

将上述的掩码区域进行如下代码操作:

# img_mask 的形状为(56,56)
# maks_windows 的形状为(64,7,7)
mask_windows = window_partition(img_mask, self.window_size)
# maks_windows 的形状为(64,49)
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
# 这样相减,就形成了维度为(64, 49, 49)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
# 使用 -100.0 填充非零部分,其余部分使用 0.0 填充
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

给出 attn_mask 的示意图(示意图还是以 6*6 大小的特征图画的):
这里至于为什么这样做,大家可以看出,下面 0 的区域就是上面 mask 模板中的所有部分,非零区域就是多出来的部分。

在这里插入图片描述
这里就做好了位置掩码操作,只需要在正向传播的过程中和特征图进行相加操作即可。这里对特征图的相加之后会经过【softmax】操作,因为特征图的值通常很小,加上 -100 之后,就会成为-100附近的值,然后经过 softmax 就变成了趋近于 0 的数,这样达到了掩码的作用。

5. downsample 操作

在原论文中,给出了这个流程图,可以注意到特征图大小的减少,和通道数的增加。
在这里插入图片描述
下面给出 downsample 下采样代码

def forward(self, x):
    H, W = self.input_resolution
    B, L, C = x.shape
    assert L == H * W, "input feature has wrong size"
    assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
    x = x.view(B, H, W, C)
    x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
    x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
    x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
    x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
    x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
    x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
    x = self.norm(x)
    x = self.reduction(x)
    return x

这里的下采样操作和 【YOLOV5】中的 【Focus】的操作是类似的,大家可以看我的另一篇博客: [YOLOV5 模块理解]
就是将特征图的像素横纵每隔一个取出一个,然后得到四个高宽减半的特征图,然后进行拼接。

总结

这里对 swin transformer 模型的主要模块进行了理解。其中最复杂的就是 mask 部分,感觉还是不太清楚。

  • 2
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
Transformer模块和Swin Transformer模块都是用于自然语言处理和计算机视觉任务的深度学习模型。它们之间的主要区别在于结构和应用领域。 Transformer模块是一种基于自注意力机制的神经网络模型,最初被提出用于机器翻译任务。它由编码器和解码器组成,通过多层堆叠的自注意力层和前馈神经网络层来捕捉输入序列中的上下文信息。Transformer模块的关键思想是通过自注意力机制来建立输入序列中各个位置之间的依赖关系,从而实现对序列的全局建模。Transformer模块在自然语言处理任务中取得了很大的成功,并被广泛应用于机器翻译、文本生成、语言理解等领域。 Swin Transformer模块是一种基于Transformer的计算机视觉模型,专门用于图像分类任务。与传统的Transformer模块不同,Swin Transformer模块引入了局部窗口机制,将输入图像划分为一系列局部窗口,并在每个窗口内进行自注意力计算。这种局部窗口机制可以有效减少计算复杂度,并且在保持全局感知能力的同时,增强了模型对局部细节的建模能力。Swin Transformer模块在计算机视觉任务中取得了很好的性能,尤其在大规模图像分类任务上表现出色。 总结来说,Transformer模块主要应用于自然语言处理任务,而Swin Transformer模块则是专门为计算机视觉任务设计的一种变种。它们在结构和应用领域上存在一些差异,但都基于自注意力机制,并具有良好的建模能力和性能表现。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值