07 Swin-Transformer

目录

一、 理论知识

1. 整体架构

2. stage中的Patch Merging

3. W-MSA和SW-MSA

3.1 W-MSA

3.2 SW-MSA

 4. Relative Position Bias

5. 参数配置

二、 网络复现

1. 网络搭建

2. train


个人学习笔记

资料来源: 12.1 Swin-Transformer网络结构详解_哔哩哔哩_bilibili

一、 理论知识

1. 整体架构

  •  从stage2开始,每经过stage,特征图减半,通道数翻倍
  •  a图堆叠的Swin-Transformer-Block是b图的2个

b图2个区别在于 W-MSA和SW-MSA

2. stage中的Patch Merging

stage1之前的 Pach Parition +++++ stage1的 Linear Embedding

在代码实现中就相当于一个Patch Merging

通过Patch Merging层后,feature map的高和宽会减半,深度会翻倍

3. W-MSA和SW-MSA

先看VIT模型和本文对比:

 VIT模型类似transformer中的Encoder:

最开始通过卷积下采样16倍,之后在整个过程特征图大小不变

Attention计算是整个特征图的计算,计算量大

SVIT:

最开始下采样4倍,之后每经过stage,特征图减半,最终下采样16倍;(多尺度特征

在每个特征图内划分窗口(windows);每个窗口内部元素进行Attention计算,而不是整个特征图计算(大大减小运算量

3.1 W-MSA

以下图为例:

假设下图中左图  是训练过程中某个特征图;VIT对图中任意两个点之间进行计算

W-MSA就是先划分为4个窗口,在每个窗口内部单独进行计算,减小计算量

这样做有一个问题:

窗口之间割裂,没有信息交互;为此要进行SW-MSA模块,即进行偏移的W-MSA。

所以Block是成对出现的,一个使用W-MSA,一个使用SW-MSA

3.2 SW-MSA

block成对出现,一个使用W-MSA,一个使用SW-MSA

第一个Block如下图左图划分(W-MSA); 那么下一个Block如下图右图划分(SW-MSA)。

SW-MSA可以信息融合,

比如右图第一排第二个窗口     融合了      左图最上面两个窗口信息

       右图第二排第二个窗口     融合了      左图最四个窗口的信息

右图的窗口是如何划分:

左上角窗口  分别向右侧和下方偏移了M/2个像素(M特征图大小)

 SW-MSA如何计算?

1. 对划分出来的9窗口进行标记

 

2. 窗口移动: 0 1 2移到最下面,之后3 6 0 移到最右面

移动完后,4是一个单独的窗口;将5和3合并成一个窗口;

7和1合并成一个窗口;8、6、2和0合并成一个窗口。这样又和原来一样是4个4x4的窗口

 3. 开始计算

计算 4 ;      计算5 3;      计算7 1;      计算8 6 2 0;

4单独计算; 5 3要通过mask;7 1要通过mask;8 6 2 0要通过mask

通过设置mask来隔绝不同区域的信息

SW-MSA 计算如下图:

对0像素点进行计算时,只让它和区域5内的像素进行匹配。那么我们可以将像素0与区域3中的所有像素匹配结果都减去100由于α的值很小,一般都是零点几的数,将其减去100后在通过SoftMax得到对应的权重都等于0了。所以对于像素0而言实际上还是只和区域5内的像素进行了计算。

 4. Relative Position Bias

计算公式的变化。使用了相对位置偏执后给够带来明显的提升

建议看视频,文字表述不太清楚

5. 参数配置

二、 网络复现

1. 网络搭建

1.  PatchMerging

#########  x.shape = [B, H, W, C] 
x0 = x[:, 0::2, 0::2, :]  # [B, H/2, W/2, C]                
# 每隔2个取一个,相当于大图片分成了4个小图片,分开
x1 = x[:, 1::2, 0::2, :]  # [B, H/2, W/2, C]
x2 = x[:, 0::2, 1::2, :]  # [B, H/2, W/2, C]
x3 = x[:, 1::2, 1::2, :]  # [B, H/2, W/2, C]
x = torch.cat([x0, x1, x2, x3], -1)  # [B, H/2, W/2, 4*C]    # 按通道拼接
x = x.view(B, -1, 4 * C)  # [B, H/2*W/2, 4*C]

2. 窗口划分

torch.contiguous()方法(详见超链接)

常与torch.permute()、torch.transpose()、torch.view()方法一起使用

###############x.shape==B, H, W, C     窗口M

x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
# permute: [B, H//M, M, W//M, M, C] -> [B, H//M, W//M, M, M, C]
# view: [B, H//M, W//M, M, M, C] -> [B*num_windows, M, M, C]
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

3. F.pad使用

torch.nn.functional.pad (input, pad, mode=‘constant’, value=0)

  • input:需要扩充的 tensor,可以是图像数据,亦或是特征矩阵数据;
  • pad:扩充维度,预先定义某维度上的扩充参数;
  • mode:扩充方法,有三种模式,分别表示常量(constant),反射(reflect),复制(replicate);
  • value:扩充时指定补充值,value只在mode=constant有效,即使用value填充在扩充出的新维度位置,而在reflect和replicate模式下,value不可赋值;

pad参数输入(详见超链接)

成对输入,每一对对一维操作;最后一维开始

4. torch.roll(详见超链接)

窗口平移shift

2. train

  • 混淆矩阵的画法

create_confusion_matrix.py 文件

将验证的结果添加到新建的类中;调用类中方法画出混淆矩阵

  • 挑选验证集中预测错误的图片

select_incorrect_samples.py文件

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值