SwinTransformer模型及代码解读

网络整体结构

patchpatition: 用一个 4x4 大小的窗口对输入图像进行分割,分割后对每一个窗口,在 channel 方向进行展平,因此经过patchpartion 之后,特征图的宽高变为原来的 1/4 Linear embedding: 对每一个
channe 都进行一个 learnorm 处理。
每一个 stage 都会重复 Swin Transformer blocks N 次( N是偶数)。
Patch partition:
如上图,通过一个 4x4 大小的窗口对特征图进行分割,之后,对每一个窗口在深度方向 ( 按通道维度 ) 进行
展平( 通道之后的全部变成 1维)。变成
也就是说将每个像 素沿深度方向进行拼接,因为每个像素都为RGB 三通道,因此最终的 channel 48 。 在通过Linear Embedding 对通道进行调整,将通道变为 c

Patch Merging(其实就是进行下采样,将传入的特征图的HW2倍,channel翻倍)

首先经过 W-MSA 之后的黄色特征图的窗口划分如图所示
W-MSA 的窗口又经过偏移之后(右下移动 M/2 M window_size
由于上述 windows 变为了 9 个,如下图所示
经过循环移位之后 (A C 先去下方, B A 再去右方 )
此时就可以划分为 4 window
在后续的进行 MSA 自注意力之前,要加上一个 masked ,如下图
上述使用了掩码之后,就可以得到只有自己和自己之间的自注意力结果。

Relative position bias

B 就是偏移
大的特征图是将每一块在行的位置上展平得到的。

SwinTransformer

初始化

这里 num_features embed_dim 8

Patchembedpatch partion+linear Embedding

上述就是 patch partition (将特征图通过卷积划分成一个一个的 patch 之后,将 patch 按照通道维度进行展平)

PatchMerging

BasicLayer类实现swin_transformer中的每一个stage

create_mask方法

举例:使用特征图为 9x9 ,窗口为 3x3 的大小时。
第一步划分 winow ,使用的 window 尺寸 M=3
上述左侧的三个 { 就是 h_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None))里的三个切片,第一个切片从( 0 -3 ),取不到 -3
第二个切片从( -3 -1 )取不到 -1
第三个切片(-1 到末尾)
右图是
此处得到的 img_mask

window_partition函数

window_reverse函数

一开始通过 window_partition 得到 9 个窗口,接着通过 view 方法,将窗口展平。
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # [nW, 1,
Mh*Mw] - [nW, Mh*Mw, 1]
这里的 tensor 相减会 涉及广播机制 ,对于 第一个矩阵 而言,会将第二个维度的数据 复制 Mh*Mw
相当于将上图右边的矩阵的每一个行向量复制 Mh*Mw 次( 9 次)。 对于第二个矩阵,会在最后一个维度将数据复制 Mh*Mw 次对于第一个矩阵的广播机制,相当于将上述的行向量复制了Mh*Mw ( 一个窗口内像素的个数 )
举例:
将最后一行复制 9 次,如图左所示:
对于第二个矩阵,是将最后一个维度上的数据复制 9 次(因为加了一维,对每一个行向量来说又多了一个新的维度),因此,每一行上的每一个数要复制9 次,得到右边的行向量。再将左边和右边进行相减( attention 的过程 )。
相减之后得到如右图所示,右图表示同一区域的用 0 表示,不同区域的是非 0
SwinTransformerBlock 构建每一个 block 的类

MLP

回到 SwinTransformerBlock 类的 forward 函数中

WindowAttention类(实现了W-MSASW-MSA的部分功能)

上述是生成 relative_position_index 的过程。假设 window_size 2 coords_h 喝和 ·coords_w 均为 0,1 coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij")) # [2, Mh, Mw]
生成网格 coords_flatten = torch.flatten(coords, 1) # [2, Mh*Mw] 展平,从Mh 维度展平
第一行是 featuremap 上每一个像素对应的行标,第二行是 feature map上每一个像素对应的列标。
相减之后的结果如上图。得到的就是相对位置索引的矩阵。
得到 relative_persition_index
forward 函数
  • 33
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值