第6周学习:Vision Transformer + Swin Transformer
一、Vision Transformer
- 回顾Self-Attention机制
W是可以进行学习更新的参数
输入输出的 feaure 维度是相同的,是一个 plug-and-play 模块。
- 简单而言,纯vit模型由三个模块组成:
(1)Linear Projection of Flattened Patches(Embedding层)
Embedding层是将高维数据转化成低维,对于标准的Transformer模块,要求输入的是token(向量)序列,即二维矩阵[num_token, token_dim],第一个参数意思为展平后的尺寸,第二个参数为所使用卷积核个数。
例如:
(2)Transformer Encoder
Transformer Encoder其实就是重复堆叠Encoder Block L次。
Dropout:将神经元间的连接随机删除。
Droppath:将深度学习模型中的多分支结构子路径随机”删除。
Layer Norm,初步用于NLP领域,不同的是,BN是对一个batch数据的每个channel进行Norm处理,但LN是对单个数据的指定维度进行Norm处理,与batch无关。(3)MLP Head(最终用于分类的层结构)
上面通过Transformer Encoder后输出的shape和输入的shape是保持不变的,所以我们只需要提取出[class]token生成的对应结果
二、Swin Transformer
在Swin Transformer中使用了Windows Multi-Head Self-Attention(W-MSA)的概念,特征图的下采样率改变,并将特征图划分成了多个不相交的区域(Window),并且MSA只在每个Window内进行。相对于Vision Transformer中直接对整个(Global)特征图进行Multi-Head Self-Attention,这样做的目的是能够减少计算量的,尤其是在浅层特征图很大的时候。并提出了 Shifted Windows Multi-Head Self-Attention(SW-MSA)的概念,通过此方法能够让信息在相邻的窗口中进行传递。
- 网络结构
- 源码中Patch Partition和Linear Embeding就是直接通过一个卷积层实现的,和之前Vision Transformer中讲的 Embedding层结构一模一样。
-
Patch Merging
可以看到Stage 2,3,4每个Transformer Block之前都要执行该操作。如下图所示,假设输入Patch Merging的是一个4x4大小的单通道特征图(feature map),Patch Merging会将每个2x2的相邻像素划分为一个patch,然后将每个patch中相同位置(同一颜色)像素给拼在一起就得到了4个feature map。接着将这四个feature map在深度方向进行concat拼接,然后在通过一个LayerNorm层。最后通过一个全连接层在feature map的深度方向做线性变化,将feature map的深度由C变成C/2。通过这个简单的例子可以看出,通过Patch Merging层后,feature map的高和宽会减半,深度会翻倍。
-
SW-MSA
引入Windows Multi-head Self-Attention模块是为了减少计算量,引入Shfited Window,W-MSA和SW-MSA是成对使用的,L层是W-MSA,则L+1层是SW-MSA,依靠跨域window解决了不同窗口之间无法进行信息交流的问题。
每个像素的操作均类似
个人理解,经过该操作,在self-attention中可以使原本不同的window乘以相同的W获得q k v,从而不同window之间的信息得以使用。为了更高效的计算,引入mask操作,使原来相分隔的区域仍保持独立计算。在计算完后还要把数据给挪回到原来的位置上(例如上述的A,B,C区域)。 -
Relative Position Bias
原因是使用了相对位置偏执后给够带来明显的提升。其本质就是希望attention map进一步有所偏重。因为attention map中某个值越低,经过softmax之后,该值会更低。对最终特征的贡献就低。
比如现在有一个2 × 2 2 \times 22×2的特征图。设置windows size 为(2,2),我们可以看看relative_position_index长什么样:
torch.Size([4, 4])
tensor([
[4, 3, 1, 0],
[5, 4, 2, 1],
[7, 6, 4, 3],
[8, 7, 5, 4]
])
注意一个参数对应一组qkv
以第一行为例,第一个元素为4,第二个元素为3;对应就是方格图中标号为1和2的位置。
其实就是第一个query和第一个key都在标号1的位置,所以相对位置为0,则都使用参数表的第4个偏置;而第2个key中的元素,位置在标号1的右边一格,用参数表的第3个参数。依次,第3个key中的元素使用第1个偏置。