[Code Walkthrough] mmaction2: Video Swin Transformer

Paper: https://arxiv.org/abs/2106.13230
Source code: https://github.com/SwinTransformer/Video-Swin-Transformer


1 Network Architecture

Windows are built over all three dimensions D, H, W, and self-attention is computed within each 3D window, so spatial and temporal correlations are extracted jointly. For example, with the default patch size (2,4,4) and window size (8,7,7), one window spans 8 patches in time and 7×7 in space, i.e. a 16×28×28 voxel region of the raw clip.
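As a concrete illustration, here is a rough shape trace assuming the default patch_size (2,4,4) and window_size (8,7,7) discussed below; the 32×224×224 input resolution is just an illustrative choice:

```python
# Rough shape trace with the defaults used in this post.
B, C, D, H, W = 1, 3, 32, 224, 224      # clip: (batch, rgb, frames, height, width)
Pd, Ph, Pw = 2, 4, 4                    # 3D patch size
Wd, Wh, Ww = 8, 7, 7                    # 3D window size

tokens = (D // Pd, H // Ph, W // Pw)    # (16, 56, 56) tokens after patch embedding
windows = (tokens[0] // Wd, tokens[1] // Wh, tokens[2] // Ww)  # (2, 8, 8) windows
print(tokens, windows, Wd * Wh * Ww)    # each window holds 392 mutually attending tokens
```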

1.1 Code

Video-Swin-Transformer/blob/master/mmaction/models/backbones/swin_transformer.py

1.2 Walkthrough

  • SwinTransformer3D

    • patch_embed: PatchEmbed3D
      Splits the input 3D signal into 3D patches (patch_size defaults to (2,4,4)); each patch is embedded and downsampled with a 3D convolution. A minimal sketch follows this list chunk.
      • padding: zero-pads any dimension not divisible by patch_size
      • self.proj = conv3d(3, 96, kernel_size=patch_size, stride=patch_size): a 3D convolution over the input, i.e. feature extraction over each patch_size window, producing one 96-dim feature per patch
      • norm (optional): flatten + transpose + layer_norm (normalizes over the channel dim, i.e. each patch's 96-dim feature) + transpose
    • pos_drop: nn.Dropout
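A minimal sketch of PatchEmbed3D as described above, simplified from the repo's swin_transformer.py (options and edge cases trimmed):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbed3D(nn.Module):
    """Minimal sketch of the 3D patch embedding described above."""
    def __init__(self, patch_size=(2, 4, 4), in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        self.patch_size = patch_size
        # one 96-dim feature per (2,4,4) patch: conv kernel == stride == patch_size
        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer is not None else None

    def forward(self, x):
        # zero-pad D, H, W so each is divisible by the patch size
        _, _, D, H, W = x.shape
        pd, ph, pw = self.patch_size
        x = F.pad(x, (0, (-W) % pw, 0, (-H) % ph, 0, (-D) % pd))
        x = self.proj(x)  # (B, embed_dim, D/pd, H/ph, W/pw)
        if self.norm is not None:
            B, C, Dp, Hp, Wp = x.shape
            x = x.flatten(2).transpose(1, 2)   # (B, N, C): one row per patch
            x = self.norm(x)                   # LayerNorm over the 96-dim patch feature
            x = x.transpose(1, 2).view(B, C, Dp, Hp, Wp)
        return x
```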

  • self.layers: depths [2, 2, 6, 2], i.e. four BasicLayer stages chained in sequence

    • BasicLayer further splits the previous stage's output into 3D windows (window_size defaults to (8,7,7)) and models the correlations between patches
      • get_window_size((D,H,W), window_size=(8,7,7), shift_size=(4,3,3))
      • rearrange(x, 'b c d h w -> b d h w c')
      • self.attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device): builds the attention mask from the input size and window_size, suppressing attention between features that belong to different windows; a condensed version is sketched below
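A condensed version of compute_mask, closely following the source: each voxel is labeled with the region it occupied before the cyclic shift, and any token pair whose labels differ gets a large negative attention bias:

```python
import torch

def window_partition(x, ws):
    # (B, D, H, W, C) -> (B*num_windows, Wd*Wh*Ww, C)
    B, D, H, W, C = x.shape
    x = x.view(B, D // ws[0], ws[0], H // ws[1], ws[1], W // ws[2], ws[2], C)
    return x.permute(0, 1, 3, 5, 2, 4, 6, 7).reshape(-1, ws[0] * ws[1] * ws[2], C)

def compute_mask(D, H, W, ws, ss, device):
    """D, H, W are the padded sizes Dp, Hp, Wp (divisible by the window size)."""
    img_mask = torch.zeros((1, D, H, W, 1), device=device)
    cnt = 0
    # label the 3x3x3 regions produced by the cyclic shift along each axis
    for d in (slice(-ws[0]), slice(-ws[0], -ss[0]), slice(-ss[0], None)):
        for h in (slice(-ws[1]), slice(-ws[1], -ss[1]), slice(-ss[1], None)):
            for w in (slice(-ws[2]), slice(-ws[2], -ss[2]), slice(-ss[2], None)):
                img_mask[:, d, h, w, :] = cnt
                cnt += 1
    mask_windows = window_partition(img_mask, ws).squeeze(-1)          # (nW, Wd*Wh*Ww)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # pairwise label diff
    # -100 on cross-region pairs: effectively zero weight after softmax
    return attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
```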
    • nn.ModuleList([SwinTransformerBlock3D(...) for i in range(depth)]): chains several SwinTransformerBlock3D modules; input/output shape (B, D, H, W, C). The bullets below detail one block's forward pass; a condensed sketch comes first.
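In place of the figure here, a condensed sketch of the attention half of the block's forward pass. window_partition is the helper from the mask sketch above, `attn` is a WindowAttention3D module (sketched after the attention bullets below), and ws/ss are the effective window and shift sizes returned by get_window_size:

```python
import torch
import torch.nn.functional as F

def window_reverse(windows, ws, B, Dp, Hp, Wp):
    # inverse of window_partition: (B*nW, Wd, Wh, Ww, C) -> (B, Dp, Hp, Wp, C)
    x = windows.view(B, Dp // ws[0], Hp // ws[1], Wp // ws[2], ws[0], ws[1], ws[2], -1)
    return x.permute(0, 1, 4, 2, 5, 3, 6, 7).reshape(B, Dp, Hp, Wp, -1)

def block_forward(x, norm1, attn, ws, ss, attn_mask):
    """Condensed SwinTransformerBlock3D attention path; x is (B, D, H, W, C)."""
    B, D, H, W, C = x.shape
    x = norm1(x)
    # pad so D, H, W are divisible by the window size
    x = F.pad(x, (0, 0, 0, (-W) % ws[2], 0, (-H) % ws[1], 0, (-D) % ws[0]))
    _, Dp, Hp, Wp, _ = x.shape
    # cyclic shift (only in the shifted-window blocks)
    if any(s > 0 for s in ss):
        x = torch.roll(x, shifts=(-ss[0], -ss[1], -ss[2]), dims=(1, 2, 3))
    else:
        attn_mask = None
    # every window becomes one attention "sentence": (B*nW, Wd*Wh*Ww, C)
    x_windows = window_partition(x, ws)
    attn_windows = attn(x_windows, mask=attn_mask)
    # undo the window split and the cyclic shift, then drop the padding
    x = window_reverse(attn_windows.view(-1, *ws, C), ws, B, Dp, Hp, Wp)
    if any(s > 0 for s in ss):
        x = torch.roll(x, shifts=(ss[0], ss[1], ss[2]), dims=(1, 2, 3))
    return x[:, :D, :H, :W, :]
```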
      • nn.LayerNorm
      • F.pad: pads D, H, W up to multiples of the window size
      • torch.roll (optional): cyclic shift for the shifted-window blocks
      • x_windows = window_partition: window split, shape (B*nW, Wd*Wh*Ww, C)
      • attn_windows = self.attn(x_windows, mask=attn_mask): WindowAttention3D runs self-attention within each window, shape (B*nW, Wd*Wh*Ww, C); a condensed sketch follows this sub-list
        • nn.Linear(dim, dim * 3, bias=qkv_bias): projects the input to 3× its width
        • q, k, v = qkv[0], qkv[1], qkv[2]: splits out the Q, K, V features
        • q * self.scale, with self.scale = head_dim ** -0.5: rescales by the per-head dimension so the logit magnitude stays stable regardless of head size
        • attn = q @ k.transpose(-2, -1): inner products between all token pairs
        • attn + relative_position_bias: a learnable relative position bias looked up from relative_position_bias_table, so the otherwise permutation-invariant attention still sees token positions
        • attn.view(B_ // nW, nW, self.num_heads, N, N) + mask: applies the activation/suppression mask, which is the self.attn_mask computed earlier
        • self.softmax(attn) + self.attn_drop(attn): standard transformer ops
        • x = (attn @ v): standard transformer op
        • self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop): standard transformer ops
        • x = shortcut + self.drop_path(x): residual connection with stochastic depth; the block then adds the FFN (norm2 + MLP) with a second residual
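A condensed sketch of WindowAttention3D covering the steps above. One simplification is deliberately non-faithful: the real module stores relative_position_bias_table with one row per relative (d,h,w) offset and gathers it via relative_position_index, whereas this sketch learns a full per-pair bias directly:

```python
import torch
import torch.nn as nn

class WindowAttention3D(nn.Module):
    """Condensed multi-head self-attention within one 3D window."""
    def __init__(self, dim, window_size, num_heads, qkv_bias=True):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5          # keeps logits comparable across head sizes
        N = window_size[0] * window_size[1] * window_size[2]
        # simplified learnable position bias (see caveat in the lead-in)
        self.relative_position_bias = nn.Parameter(torch.zeros(num_heads, N, N))
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)  # project to Q, K, V at once
        self.softmax = nn.Softmax(dim=-1)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, mask=None):
        B_, N, C = x.shape                     # B_ = B * num_windows
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)       # (3, B_, heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q * self.scale) @ k.transpose(-2, -1)      # pairwise dot products
        attn = attn + self.relative_position_bias          # learnable position bias
        if mask is not None:                   # block cross-region pairs after torch.roll
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)   # weighted sum of values
        return self.proj(x)
```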
  • downsample: PatchMerging rearranges the output features: H and W are halved (D is not downsampled) and the channel count first grows 4×; a sketch follows at the end of this walkthrough

    • strided sampling over H and W (every other row/column, four phase offsets)
    • norm: nn.LayerNorm
    • nn.Linear(4 * dim, 2 * dim): channel reduction
  • rearrange(x, 'b d h w c -> b c d h w')

  • rearrange + norm + rearrange: the final LayerNorm, applied in channel-last layout
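As promised above, a sketch of PatchMerging, closely following the source:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchMerging(nn.Module):
    """Downsample H and W by 2 (D untouched): gather the four spatial
    neighbours of each 2x2 cell into the channel dim (C -> 4C),
    normalize, then project 4C -> 2C."""
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(4 * dim)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

    def forward(self, x):
        # x: (B, D, H, W, C); pad H, W to even if needed
        B, D, H, W, C = x.shape
        x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
        # strided sampling: every other row/column, four phase offsets
        x0 = x[:, :, 0::2, 0::2, :]
        x1 = x[:, :, 1::2, 0::2, :]
        x2 = x[:, :, 0::2, 1::2, :]
        x3 = x[:, :, 1::2, 1::2, :]
        x = torch.cat([x0, x1, x2, x3], dim=-1)  # (B, D, H/2, W/2, 4C)
        return self.reduction(self.norm(x))      # (B, D, H/2, W/2, 2C)
```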

1.3 Weight Inflation: inflate_weights

To initialize from an ImageNet-pretrained 2D Swin checkpoint, inflate_weights expands the 2D weights to 3D:

  • the conv3d in patch_embed is initialized by directly inflating the pretrained 2D conv weights
  • relative_position_bias_table: two strategies, inflation (duplicate) initialization and center initialization (sketched below)
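A sketch of both ideas. inflate_patch_embed mirrors what the repo does for the patch-embedding conv; center_init_bias_table and duplicate_init_bias_table are hypothetical helper names illustrating the two bias-table strategies (the repo's inflate_weights implements the duplicate variant; "center" initialization is the second strategy the post mentions):

```python
import torch

def inflate_patch_embed(conv2d_weight, temporal_patch):
    """Inflate a 2D patch-embedding kernel (O, I, kh, kw) to 3D
    (O, I, kt, kh, kw): repeat along time and divide by kt, so the
    inflated conv initially matches the 2D one on a static clip."""
    kt = temporal_patch
    return conv2d_weight.unsqueeze(2).repeat(1, 1, kt, 1, 1) / kt

def duplicate_init_bias_table(table_2d, temporal_window):
    """Duplicate/inflate init: every temporal offset starts from the
    same 2D bias (this is what the repo's inflate_weights does)."""
    return table_2d.repeat(2 * temporal_window - 1, 1)

def center_init_bias_table(table_2d, temporal_window):
    """Hypothetical center init: only the zero temporal offset inherits
    the 2D bias; all other temporal offsets start at zero."""
    L, nH = table_2d.shape                     # (2Wh-1)*(2Ww-1) rows, nH heads
    wd = temporal_window
    table_3d = torch.zeros((2 * wd - 1) * L, nH, dtype=table_2d.dtype)
    center = wd - 1                            # index of temporal offset 0
    table_3d[center * L:(center + 1) * L] = table_2d
    return table_3d
```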

2 Experimental Results

Video Swin Transformer reports state-of-the-art accuracy on Kinetics-400, Kinetics-600, and Something-Something v2; see the paper for the full result tables.

Below is a simplified, self-contained code example of a Video Swin Transformer-style model built from a 3D convolution stem plus window attention, for reference. It is a toy sketch rather than the mmaction2 implementation: each block applies nn.MultiheadAttention over non-overlapping 3D windows, without shifted windows or relative position bias.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VideoSwinTransformerBlock(nn.Module):
    """One simplified block: window attention + MLP, both with residuals."""
    def __init__(self, dim, num_heads, window_size, drop_rate=0.0):
        super().__init__()
        self.window_size = window_size  # (Wd, Wh, Ww)
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.GELU(), nn.Dropout(drop_rate),
            nn.Linear(4 * dim, dim), nn.Dropout(drop_rate),
        )

    def forward(self, x):
        # x: (B, C, D, H, W)
        b, c, d, h, w = x.shape
        wd, wh, ww = self.window_size
        # pad D, H, W to multiples of the window size
        x = F.pad(x, (0, (-w) % ww, 0, (-h) % wh, 0, (-d) % wd))
        dp, hp, wp = x.shape[2:]
        # partition into non-overlapping 3D windows: (B*nW, Wd*Wh*Ww, C)
        x = x.view(b, c, dp // wd, wd, hp // wh, wh, wp // ww, ww)
        x = x.permute(0, 2, 4, 6, 3, 5, 7, 1).reshape(-1, wd * wh * ww, c)
        # attention and MLP, each with a residual connection
        shortcut = x
        x = self.norm1(x)
        x = shortcut + self.attn(x, x, x)[0]
        x = x + self.mlp(self.norm2(x))
        # reverse the window partition and drop the padding
        x = x.view(b, dp // wd, hp // wh, wp // ww, wd, wh, ww, c)
        x = x.permute(0, 7, 1, 4, 2, 5, 3, 6).reshape(b, c, dp, hp, wp)
        return x[:, :, :d, :h, :w]

class VideoSwinTransformer(nn.Module):
    def __init__(self, in_channels, dim, num_heads, window_sizes,
                 num_classes=1000, drop_rate=0.0):
        super().__init__()
        # 3D conv stem: embeds pixels and halves H and W
        self.conv = nn.Sequential(
            nn.Conv3d(in_channels, dim, kernel_size=(1, 3, 3),
                      stride=(1, 2, 2), padding=(0, 1, 1)),
            nn.BatchNorm3d(dim),
            nn.ReLU(inplace=True),
        )
        self.blocks = nn.ModuleList([
            VideoSwinTransformerBlock(dim, num_heads, ws, drop_rate=drop_rate)
            for ws in window_sizes
        ])
        self.norm = nn.LayerNorm(dim)
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Linear(dim, num_classes)

    def forward(self, x):
        # x: (B, C_in, T, H, W)
        x = self.conv(x)
        for block in self.blocks:
            x = block(x)
        # LayerNorm over channels in channel-last layout
        x = self.norm(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)
        x = self.pool(x).flatten(1)
        return self.fc(x)
```

Here `VideoSwinTransformerBlock` is one basic block combining window self-attention with an MLP, and `VideoSwinTransformer` assembles the conv stem, the blocks, global average pooling, and the classification head. You can adapt the parameters and model structure to your own needs.
