目录
个人学习笔记
资料来源: 12.1 Swin-Transformer网络结构详解_哔哩哔哩_bilibili
一、 理论知识
1. 整体架构
- 从stage2开始,每经过stage,特征图减半,通道数翻倍
- a图堆叠的Swin-Transformer-Block是b图的2个
b图2个区别在于 W-MSA和SW-MSA
2. stage中的Patch Merging
stage1之前的 Pach Parition +++++ stage1的 Linear Embedding
在代码实现中就相当于一个Patch Merging
通过Patch Merging层后,feature map的高和宽会减半,深度会翻倍
3. W-MSA和SW-MSA
先看VIT模型和本文对比:
VIT模型类似transformer中的Encoder:
最开始通过卷积下采样16倍,之后在整个过程特征图大小不变;
Attention计算是整个特征图的计算,计算量大。
SVIT:
最开始下采样4倍,之后每经过stage,特征图减半,最终下采样16倍;(多尺度特征)
在每个特征图内划分窗口(windows);每个窗口内部元素进行Attention计算,而不是整个特征图计算(大大减小运算量)
3.1 W-MSA
以下图为例:
假设下图中左图 是训练过程中某个特征图;VIT对图中任意两个点之间进行计算
W-MSA就是先划分为4个窗口,在每个窗口内部单独进行计算,减小计算量
这样做有一个问题:
窗口之间割裂,没有信息交互;为此要进行SW-MSA模块,即进行偏移的W-MSA。
所以Block是成对出现的,一个使用W-MSA,一个使用SW-MSA
3.2 SW-MSA
block成对出现,一个使用W-MSA,一个使用SW-MSA
第一个Block如下图左图划分(W-MSA); 那么下一个Block如下图右图划分(SW-MSA)。
SW-MSA可以信息融合,
比如右图第一排第二个窗口 融合了 左图最上面两个窗口信息
右图第二排第二个窗口 融合了 左图最四个窗口的信息
右图的窗口是如何划分:
左上角窗口 分别向右侧和下方各偏移了M/2个像素(M特征图大小)
SW-MSA如何计算?
1. 对划分出来的9窗口进行标记
2. 窗口移动: 0 1 2移到最下面,之后3 6 0 移到最右面
移动完后,4是一个单独的窗口;将5和3合并成一个窗口;
7和1合并成一个窗口;8、6、2和0合并成一个窗口。这样又和原来一样是4个4x4的窗口
3. 开始计算
计算 4 ; 计算5 3; 计算7 1; 计算8 6 2 0;
4单独计算; 5 3要通过mask;7 1要通过mask;8 6 2 0要通过mask
通过设置mask来隔绝不同区域的信息
SW-MSA 计算如下图:
对0像素点进行计算时,只让它和区域5内的像素进行匹配。那么我们可以将像素0与区域3中的所有像素匹配结果都减去100由于α的值很小,一般都是零点几的数,将其减去100后在通过SoftMax得到对应的权重都等于0了。所以对于像素0而言实际上还是只和区域5内的像素进行了计算。
4. Relative Position Bias
计算公式的变化。使用了相对位置偏执后给够带来明显的提升
建议看视频,文字表述不太清楚
5. 参数配置
二、 网络复现
1. 网络搭建
1. PatchMerging
######### x.shape = [B, H, W, C]
x0 = x[:, 0::2, 0::2, :] # [B, H/2, W/2, C]
# 每隔2个取一个,相当于大图片分成了4个小图片,分开
x1 = x[:, 1::2, 0::2, :] # [B, H/2, W/2, C]
x2 = x[:, 0::2, 1::2, :] # [B, H/2, W/2, C]
x3 = x[:, 1::2, 1::2, :] # [B, H/2, W/2, C]
x = torch.cat([x0, x1, x2, x3], -1) # [B, H/2, W/2, 4*C] # 按通道拼接
x = x.view(B, -1, 4 * C) # [B, H/2*W/2, 4*C]
2. 窗口划分
torch.contiguous()方法(详见超链接)
常与torch.permute()、torch.transpose()、torch.view()方法一起使用
###############x.shape==B, H, W, C 窗口M
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
# permute: [B, H//M, M, W//M, M, C] -> [B, H//M, W//M, M, M, C]
# view: [B, H//M, W//M, M, M, C] -> [B*num_windows, M, M, C]
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
3. F.pad使用
torch.nn.functional.pad (input, pad, mode=‘constant’, value=0)
- input:需要扩充的 tensor,可以是图像数据,亦或是特征矩阵数据;
- pad:扩充维度,预先定义某维度上的扩充参数;
- mode:扩充方法,有三种模式,分别表示常量(constant),反射(reflect),复制(replicate);
- value:扩充时指定补充值,value只在mode=constant有效,即使用value填充在扩充出的新维度位置,而在reflect和replicate模式下,value不可赋值;
pad参数输入(详见超链接)
成对输入,每一对对一维操作;最后一维开始
4. torch.roll(详见超链接)
窗口平移shift
2. train
- 混淆矩阵的画法
create_confusion_matrix.py 文件
将验证的结果添加到新建的类中;调用类中方法画出混淆矩阵
- 挑选验证集中预测错误的图片
select_incorrect_samples.py文件