Paper: https://arxiv.org/abs/2106.13230
Source code: https://github.com/SwinTransformer/Video-Swin-Transformer
1 Network Structure
Self-attention is computed inside windows built over all three dimensions (D, H, W), so spatial and temporal correlations are extracted jointly.
1.1 Code
Video-Swin-Transformer/blob/master/mmaction/models/backbones/swin_transformer.py
1.2 Walkthrough
- SwinTransformer3D
  - patch_embed: PatchEmbed3D
    - Splits the input 3D signal into 3D patches (default patch_size (2,4,4)) and applies a 3D convolution to each patch for feature extraction and downsampling.
    - padding: zero-pads any dimension that is not divisible by patch_size.
    - self.proj = nn.Conv3d(3, 96, kernel_size=patch_size, stride=patch_size): a 3D convolution whose kernel and stride both equal patch_size, so each patch-sized window of the input produces one 96-dim feature.
    - norm (optional): flatten + transpose + LayerNorm (normalizes over the channel dimension, i.e. each patch's 96-dim feature) + transpose.
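The patch-embedding step above can be sketched as follows. This is a minimal illustrative module, not the repo's code: the class name `PatchEmbed3DSketch` and the padding arithmetic are my own, while the conv3d shape, default patch_size, and the flatten/transpose/LayerNorm dance follow the description above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbed3DSketch(nn.Module):
    """Minimal sketch of 3D patch embedding: one conv3d step per (2, 4, 4) patch."""
    def __init__(self, patch_size=(2, 4, 4), in_chans=3, embed_dim=96, use_norm=True):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        # kernel == stride == patch_size: each patch yields one embed_dim-d feature
        self.proj = nn.Conv3d(in_chans, embed_dim,
                              kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim) if use_norm else None

    def forward(self, x):                          # x: (B, C, D, H, W)
        _, _, D, H, W = x.shape
        pd, ph, pw = self.patch_size
        # zero-pad dimensions not divisible by patch_size (F.pad order: W, H, D)
        x = F.pad(x, (0, (pw - W % pw) % pw,
                      0, (ph - H % ph) % ph,
                      0, (pd - D % pd) % pd))
        x = self.proj(x)                           # (B, embed_dim, D', H', W')
        if self.norm is not None:
            d, h, w = x.shape[2:]
            x = x.flatten(2).transpose(1, 2)       # (B, D'*H'*W', embed_dim)
            x = self.norm(x)                       # normalize the channel dim
            x = x.transpose(1, 2).reshape(-1, self.embed_dim, d, h, w)
        return x
```

For example, a (2, 3, 5, 9, 9) input is padded to (6, 12, 12) and embedded to (2, 96, 3, 3, 3).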
  - pos_drop: nn.Dropout
  - self.layers: depths [2, 2, 6, 2], a chain of BasicLayer modules.
    - BasicLayer further splits the upstream signal into 3D windows (default window_size (8,7,7)) and extracts the correlations between patches.
      - get_window_size((D, H, W), window_size=(8,7,7), shift_size=(4,3,3))
      - rearrange(x, 'b c d h w -> b d h w c')
      - self.attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device): builds the transformer attention mask from the input size and window_size, suppressing correlations between features that belong to different windows.
      - nn.ModuleList([SwinTransformerBlock3D(...) for i in range(depth)]): a chain of SwinTransformerBlock3D modules operating on (B, D, H, W, C).
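The 3D window split used inside BasicLayer, and its inverse, can be sketched as below. The function names are mine; the reshape/permute pattern is the standard Swin windowing generalized to (Wd, Wh, Ww), producing exactly the (B*nW, Wd*Wh*Ww, C) shape noted later in this walkthrough.

```python
import torch

def window_partition_3d(x, window_size):
    """Split (B, D, H, W, C) into non-overlapping 3D windows -> (B*nW, Wd*Wh*Ww, C)."""
    B, D, H, W, C = x.shape
    wd, wh, ww = window_size
    x = x.view(B, D // wd, wd, H // wh, wh, W // ww, ww, C)
    # group the three window axes together, then fold windows into the batch dim
    return x.permute(0, 1, 3, 5, 2, 4, 6, 7).reshape(-1, wd * wh * ww, C)

def window_reverse_3d(windows, window_size, B, D, H, W):
    """Inverse of window_partition_3d: back to (B, D, H, W, C)."""
    wd, wh, ww = window_size
    x = windows.view(B, D // wd, H // wh, W // ww, wd, wh, ww, -1)
    return x.permute(0, 1, 4, 2, 5, 3, 6, 7).reshape(B, D, H, W, -1)
```

With the default window_size (8,7,7), a (2, 8, 14, 14, 96) tensor yields 4 windows per sample, i.e. shape (8, 392, 96), and the reverse reproduces the input exactly.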
      - SwinTransformerBlock3D (forward path)
        - nn.LayerNorm
        - F.pad
        - torch.roll (optional, for shifted windows)
        - x_windows = window_partition(...): window split, shape (B*nW, Wd*Wh*Ww, C)
        - attn_windows = self.attn(x_windows, mask=attn_mask): WindowAttention3D, self-attention within each window, shape (B*nW, Wd*Wh*Ww, C)
          - nn.Linear(dim, dim * 3, bias=qkv_bias): widens the input 3x; q, k, v = qkv[0], qkv[1], qkv[2] then extracts the Q, K, V features.
          - q * self.scale, with self.scale = head_dim ** -0.5: scales by the per-head size so the multi-head split does not let the dot-product magnitude blow up.
          - attn = q @ k.transpose(-2, -1): inner product.
          - attn + relative_position_bias (looked up from relative_position_bias_table): adds a learnable relative position bias, so the transformer module is not blind to feature ordering.
          - attn.view(B_ // nW, nW, self.num_heads, N, N) + mask: applies the activation/suppression mask on cross-window correlations; this mask is the self.attn_mask computed earlier.
          - self.softmax(attn) + self.attn_drop(attn): standard transformer step.
          - x = attn @ v: standard transformer step.
          - self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop): standard transformer step.
        - x = shortcut + self.drop_path(x)
        - FFN module
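The attention steps listed above can be sketched as one module. This is a simplified version: the relative position bias lookup is omitted for brevity, the class name is mine, and dim=96 / num_heads=3 are assumed defaults; the qkv widening, head scaling, mask addition, and softmax follow the walkthrough.

```python
import torch
import torch.nn as nn

class WindowAttentionSketch(nn.Module):
    """Multi-head self-attention inside one window (relative position bias omitted)."""
    def __init__(self, dim=96, num_heads=3, qkv_bias=True):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5                       # per-head scaling
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)   # 3x widening -> q, k, v
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, mask=None):            # x: (B*nW, N, C), N = Wd*Wh*Ww
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)    # each: (B_, heads, N, head_dim)
        attn = (q * self.scale) @ k.transpose(-2, -1)        # (B_, heads, N, N)
        if mask is not None:                    # shifted-window mask: (nW, N, N)
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) \
                 + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        return self.proj(x)
```

An all-zero mask leaves the attention unchanged; compute_mask instead fills cross-window entries with large negative values so softmax drives them to ~0.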
    - downsample: PatchMerging
      - Rearranges the output features: H and W become 1/2 (D is not downsampled) and the channel count becomes 4x, via interleaved (stride-2) sampling over H and W.
      - norm: nn.LayerNorm
      - nn.Linear(4 * dim, 2 * dim): channel reduction.
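The PatchMerging step can be sketched as follows (class name mine; the four interleaved slices, LayerNorm, and 4C-to-2C linear follow the description above, with D left untouched):

```python
import torch
import torch.nn as nn

class PatchMergingSketch(nn.Module):
    """Stride-2 interleaved sampling over H and W (D untouched): C -> 4C -> 2C."""
    def __init__(self, dim=96):
        super().__init__()
        self.norm = nn.LayerNorm(4 * dim)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)  # channel reduction

    def forward(self, x):                        # x: (B, D, H, W, C), H and W even
        x0 = x[:, :, 0::2, 0::2, :]              # four interleaved H/W samples
        x1 = x[:, :, 1::2, 0::2, :]
        x2 = x[:, :, 0::2, 1::2, :]
        x3 = x[:, :, 1::2, 1::2, :]
        x = torch.cat([x0, x1, x2, x3], dim=-1)  # (B, D, H/2, W/2, 4C)
        return self.reduction(self.norm(x))      # (B, D, H/2, W/2, 2C)
```

So a (1, 4, 8, 8, 96) input becomes (1, 4, 4, 4, 192): spatial resolution halved, channels doubled, temporal length preserved.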
  - rearrange(x, 'b d h w c -> b c d h w')
  - rearrange + norm + rearrange (final output)
2 Swin-Transformer Weight Inflation
- inflate_weights
  - The conv3d in patch_embed is initialized by directly inflating the pretrained conv2d weights.
  - relative_position_bias_table: two options, inflation initialization and center initialization.
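The two initialization ideas can be sketched as below. This is an assumption-laden illustration, not the repo's inflate_weights: the function names and the exact table layout are mine. The conv inflation repeats the 2D kernel along time and divides by the temporal kernel size, which preserves the layer's output on a temporally constant input; "center" init places the 2D bias table at the zero temporal offset only, while "inflate" init copies it to every temporal offset.

```python
import torch

def inflate_conv2d_to_3d(w2d, temporal_k):
    """Inflate a (Cout, Cin, kh, kw) conv2d kernel to (Cout, Cin, kd, kh, kw).
    Dividing by temporal_k keeps the response to a temporally constant input
    identical to the original 2D layer's response."""
    return w2d.unsqueeze(2).repeat(1, 1, temporal_k, 1, 1) / temporal_k

def expand_bias_table(table2d, wd, center=True):
    """Expand a 2D bias table of shape ((2Wh-1)*(2Ww-1), heads) to the 3D
    shape ((2Wd-1)*(2Wh-1)*(2Ww-1), heads) for temporal window size wd.
    center=True: only temporal offset 0 gets the 2D values (rest zero);
    center=False: every temporal offset gets a copy ("inflation")."""
    L, heads = table2d.shape
    n_t = 2 * wd - 1
    if center:
        table3d = torch.zeros(n_t, L, heads)
        table3d[n_t // 2] = table2d
    else:
        table3d = table2d.unsqueeze(0).repeat(n_t, 1, 1)
    return table3d.reshape(n_t * L, heads)
```

For the defaults here (patch_size[0] = 2, window Wd = 8), a 2D patch-embed kernel becomes a depth-2 kernel at half magnitude, and a (13*13, heads) bias table grows by a factor of 2*8-1 = 15 rows.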