Swin Transformer模型解读(附源码+论文)

Swin Transformer

论文链接:Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

官方链接:Swin-Transformer

代码补充:Swin-Replenishment(因为官方给的代码我在win下单gpu运行会报错,所以我修改了一些代码使其能够运行)

总体流程

在深度学习的视觉任务中,都会先经过backbone提取图片的特征。优秀的backbone可以事半功倍。我学习过很多模型,它们用的backbone大部分都是ResNet(经典永不过时)。近几年也演化了一些新的backbone,比如EfficientNet(之前讲过,计算量较低且效果好)、VIT(之前讲过,基于Transformer的,不过计算量很大,后续有很多优化版本),而今天介绍的是Swin Transformer (据说效果吊打)。

在VIT中,会图像分割成固定大小的 Patch然后排成一个序列,然后通过计算全局自注意力,计算量巨大,遇到高分辨率图像直接跪下。其次,VIT采用固定的patch,因此无法适应多尺度特征。

Swin Transformer灵机一动,通过分层特征提取窗口自注意力机制 解决这些问题。分层特征提取如下图,在初始阶段将图片划分为很小的patch,然后逐步合并patch(相邻 Patch 合并,形成更大的感受野)。这样可以逐步降低分辨率,提取高层语义信息,同时减少计算量。

请添加图片描述

窗口自注意力机制 主要在Swin Transformer中,流程图如下,由Patch EmbeddingSwin Transformer Block Patch Merging 组成。

请添加图片描述

Patch Embedding就是VIT里的核心步骤,将图片划分为固定大小的patch,然后通过一个卷积将每个 Patch 映射到固定维度的特征向量。假设输入的图像数据为(224,224,3),通过一个卷积Conv2d(3,96,kernel_size(4,4),stride=(4,4))映射到一个序列(3136,96)。(注:96 是一个经验值,3136=224÷4×224÷4,kernel_size=patch_size,确保每个 Patch 的信息被映射到一个嵌入向量中,stride=patch_size,确保不重叠分割。)

Swin Transformer Block 是Swin Transformer里的核心步骤,如下图,它由W-MSASW_MSA 组成。W表示窗口、SW表示滑动窗口。
请添加图片描述

首先对输入的特征图分成窗口,比如输入特征图(56,56,96),窗口size为7,那么可以分成8×8个窗口,输出(64,7,7,96)。得到窗口后,计算每个窗口的注意力得分,那么就见到我们的老朋友qkv了。就是老规矩,先q@k,然后加上位置编码再@v,代码里详说,最终得到每个窗口的注意力得分。

这边就已经过完了W-MSA ,但是这个过程吧,窗口自己跟自己玩,不同窗口之间的一个关系没有体现出来。

问题:不同窗口之间的 token 无法直接通信,这限制了全局信息传播

| A  A  A | B  B  B | C  C  C |
| A  A  A | B  B  B | C  C  C |
| A  A  A | B  B  B | C  C  C |

因此作者后面又接了一个SW-MSA ,滑动窗口。让B参与A的窗口,C参与B的窗口,A参与C的窗口。但是这也有个问题啊,A和C之间并不是一个连续的位置关系。

问题:如果直接计算自注意力,模型可能会错误地学习跨窗口不连续的位置关系

| A  A  B | B  B  C | C  C  A |
| A  A  B | B  B  C | C  C  A |
| A  A  B | B  B  C | C  C  A |

SW-MSA的流程如下图,本来四块窗口的通过滑动后被划分为了九个窗口。
请添加图片描述

图片位移后的结果如下。

请添加图片描述

计算时仍然按4个窗口计算,但是很显然除了左上那个4,其它三个窗口都存在不连续的位置关系。怎么办?那就不计算那些不连续的数据呗,用一个mask矩阵屏蔽原始窗口间的无效连接,如下图。
请添加图片描述

OK,Swin Transformer Block 部分结束 ,接下来进行Patch MergingPatch Merging就是进行一个下采样,而且画个图就能理解,过程非常简单,如下图。
请添加图片描述

对H和W维度进行建个采样后拼接在一起,每次操作后H和W减一半,C多一倍,如下图红色框。
请添加图片描述

代码

官网有很多例子可以用,如下图,学习的话可以选第一个图像分类任务做一下。
请添加图片描述

代码不讲基础的部分,直接进入models/swin_transformer.py讲网络模块。配置一下参数进入debug,configs下有很多网络配置,选一个自己喜欢的。–local_rank 0表示用单gpu,不过设不设置都无所谓,因为官方代码就不支持win单GPU运行,得自己改点代码才行。–data-path是自己数据的路径。

--cfg configs/swin/swin_tiny_patch4_window7_224.yaml
--batch-size 16
--local_rank 0
--data-path imagenet

SwinTransformer

首先呢,debug到SwinTransformer的forward_features方法。

def forward_features(self, x):  # (b,3,224,224)
    x = self.patch_embed(x)  # (b,3136,96)
    if self.ape:
        x = x + self.absolute_pos_embed
    x = self.pos_drop(x)

    for layer in self.layers:
        x = layer(x)

    x = self.norm(x)  # B L C (b,49,768)
    x = self.avgpool(x.transpose(1, 2))  # B C 1 (b,768,1)
    x = torch.flatten(x, 1)  # (b,768)
    return x

self.patch_embed没啥好说的,就是一个卷积将图片数据分patch然后排成一排,输入(b,3,224,224)输出(b,3136,96)。(注:3136=224÷4×224÷4,通道维度3转为96)
请添加图片描述

主要是下面的for循环,将经历W-MSA和SW-MSA。我们跳入SwinTransformerBlock的forward方法中。

W-MSA

代码中if self.shift_size > 0后面接的就是SW-MSA,else接的是W-MSA。

def forward(self, x):  # (b,3136,96)
    H, W = self.input_resolution
    B, L, C = x.shape
    assert L == H * W, "input feature has wrong size"

    shortcut = x
    x = self.norm1(x)
    x = x.view(B, H, W, C)  # (b,56,56,96)

    # cyclic shift
    if self.shift_size > 0:
        if not self.fused_window_process:
            # [[ 1,  2,  3,  4],                 [[ 6,  7,  8,  5],
            #  [ 5,  6,  7,  8],       ==>        [10, 11, 12,  9],
            #  [ 9, 10, 11, 12], shift_size = 1   [14, 15, 16, 13],
            #  [13, 14, 15, 16]]                  [ 2,  3,  4,  1]]
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))  # 窗口滑动
            # partition windows
            x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        else:
            x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
    else:
        shifted_x = x
        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, w_s, w_s, C (64×b,7,7,96)

    x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, w_s*w_s, C (64×b,49,96)

    # W-MSA/SW-MSA
    attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, w_s*w_s, C (64×b,49,96)

    # merge windows
    attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)  # (64×b,7,7,96)

    # reverse cyclic shift
    if self.shift_size > 0:
        if not self.fused_window_process:
            shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))  # 窗口还原
        else:
            x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
    else:
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C (b,56,56,96) 还原回去
        x = shifted_x
    x = x.view(B, H * W, C)  # (b,3136,96)
    x = shortcut + self.drop_path(x)

    # FFN
    x = x + self.drop_path(self.mlp(self.norm2(x)))

    return x

输入的size为(b,3136,96),经过标准化后重新view为(b,56,56,96)。(注:56×56=3136)。第一步是W-MSA,所以先执行else后的代码,window_partition会将图像划分窗口,具体进入函数看一下。

window_partition

def window_partition(x, window_size):
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)  # (b,8,7,8,7,96)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)  # (64×b,7,7,96)
    return windows  # 将原来56×56的窗口划分为为7×7的,然后排成一排

默认窗口size为7,现在输入的图像是56×56的,按7×7的窗口将划分8×8块窗口,最终输出(64×b,7,7,96)

=======================================================================

回到SwinTransformerBlock的forward继续。

将刚刚划分好的窗口先重新view一下为(64×b,49,96),然后带入到self.attn计算注意力分数,我们进入attn方法中看一下怎么个事。

WindowAttention

进入WindowAttention的forward方法中。

def forward(self, x, mask=None):
    B_, N, C = x.shape
    qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)  # (3,64×b,3,49,32)
    q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

    q = q * self.scale  # 缩放Q 防止点积结果过大,保持数值稳定
    attn = (q @ k.transpose(-2, -1))  # k:(49,32)->(32,49) q·k->(49,49)

    relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(  # 位置编码
        self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
    relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
    attn = attn + relative_position_bias.unsqueeze(0)  # (64×b,3,49,49)

    if mask is not None:
        nW = mask.shape[0]
        attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
        attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)
    else:
        attn = self.softmax(attn)

    attn = self.attn_drop(attn)

    x = (attn @ v).transpose(1, 2).reshape(B_, N, C)  # (b,3,49,49)·(b,3,49,32)->(b,3,49,32)->(b,49,3,32)->(b,49,96)
    x = self.proj(x)
    x = self.proj_drop(x)
    return x

现在输入的维度为(64×b,49,96)。这里qkv矩阵是直接通过一个线性变换self.qkv得到,self.num_heads注意力头默认为3,得到输出(3,64×b,3,49,32)。(注:3为qkv三个矩阵,3为注意力头,49=7×7窗口大小,32=96÷3每个head特征)

然后分别通过qkv[0], qkv[1], qkv[2]取得qkv的矩阵。下面开始计算注意力分数。

self.scale是为了缩放q ,防止点积结果过大,保持数值稳定。然后与k点积,点积前将k转置一下,得到atten(64×b,3,49,49)。relative_position_bias是位置编码,这个就是套死公式,然后atten加上这个位置编码。现在还在W-MSA阶段,所以没有mask,直接对attn进行softmax。atten与v的转置点积,reshape得到(b,49,96),最后经过一个线性变换self.proj完成!

======================================================================

回到SwinTransformerBlock的forward继续。

将刚刚得到的重新view一下为(64×b,7,7,96)。window_reverse是将图像数据的size还原回原来的样子(64×b,7,7,96)->(b,56,56,96)。得到的结果view成(b,3136,96),最后做点线性变换做残差得到最终结果。

SW-MSA

W-MSA部分算是结束了,紧接着进入SW-MSA模块,重复的地方就不说了。还是在SwinTransformerBlock的forward方法中,这次将执行if self.shift_size > 0后的代码。

shift_size默认值为3,torch.roll就是用来窗口滑动的。我画了个图来理解它,加入shift_size为1的情况,它的效果如下图。

请添加图片描述

接下来不一样的点在计算注意力机制模块,WindowAttention的forward方法中。

if mask is not None:
    nW = mask.shape[0]
    attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
    attn = attn.view(-1, self.num_heads, N, N)
    attn = self.softmax(attn)

这里的attn还加了一个mask矩阵,mask矩阵是怎么来的呢。

H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
h_slices = (slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

先创建一个全零的掩码张量img_mask,shape为 (1, H, W, 1),用来存储窗口索引。h_slices和w_slices是用来划分区域的切片(其实我没看懂,只是大概知道是干啥的)。cnt为不同的窗口区域编号,然后按刚刚的切片填充不同的索引值,用来区分不同区域。这样,相同窗口的像素会有相同的cnt值,不同窗口的像素值不同。将img_mask按7×7进行窗口划分为mask_windows,展平,计算像素间的差值矩阵,相同窗口内的元素差值为 0,不同窗口之间的差值不为 0。不同窗口的填充-100,相同的填充0。

因此atten加上mask矩阵,对于连续窗口的token不做变化,而对于不连续窗口的token加上-100,经过softmax近似忽略。

做完这些之后要进入下采样了。

PatchMerging

跳到PatchMerging的forward方法中看看怎么个事。

def forward(self, x):
    H, W = self.input_resolution
    B, L, C = x.shape
    assert L == H * W, "input feature has wrong size"
    assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

    x = x.view(B, H, W, C)

    x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
    x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
    x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
    x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
    x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
    x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

    x = self.norm(x)
    x = self.reduction(x)

    return x

代码简洁明了,非常清晰。按H和W每隔一个像素点取值,然后拼接在一起。标准化后再经过一个线性变换,结束!输出(b,784,192)。(注:784=56÷2×56÷2,即H和W都减半。C一开始叠了4倍,后面经过self.reduction减半了,所以192=96×4÷2)

======================================================================

后面再重复几次,每次HW减半,C两倍,over!

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值