VSSM VMamba实现

本文详细介绍了VSSM模型的构建过程,包括维度变换的参数设置、模型参数初始化方法、VSSBlock结构、patchembed模块以及下采样和分类器的设计。重点展示了如何通过SS2D和MLP分支来处理视觉信息并进行分类任务。
摘要由CSDN通过智能技术生成

VSSM

Mamba实现可以参照之前的
mamba_minimal系列
论文地址:
VMamba
论文阅读:
VMamba:视觉状态空间模型
代码地址:
https://github.com/MzeroMiko/VMamba.git
SS2D实现

以分类任务用到的VMamba为例。

维度变换

操作的具体参数定义见初始化

阶段维度
输入x [ B , C , H , W ] [B, C, H, W] [B,C,H,W]
embed [ B , H / 4 , W / 4 , C 1 ] [B, H/4, W/4, C_1 ] [B,H/4,W/4,C1]
阶段1 [ B , H / 4 , W / 4 , C 1 ] [B, H/4, W/4, C_1 ] [B,H/4,W/4,C1]
阶段2 [ B , H / 8 , W / 8 , C 2 ] [B, H/8, W/8, C_2 ] [B,H/8,W/8,C2]
阶段3 [ B , H / 16 , W / 16 , C 3 ] [B, H/16, W/16, C_3 ] [B,H/16,W/16,C3]
阶段4 [ B , H / 32 , W / 32 , C 4 ] [B, H/32, W/32, C_4 ] [B,H/32,W/32,C4]
分类器 [ B , 1000 ] [B, 1000 ] [B,1000]

在这里插入图片描述

初始化

参数定义说明
in_chans3输入图像的通道数
depths[2, 2, 9, 2]定义每层的VSS Block数
dims[96, 192, 384, 768]定义每层的输出通道数
downsample_versionv2下采样操作的版本
patchembed_versionv1图像嵌入
mlp_ratio4.0定义mlp隐藏维度缩放
ssm_d_state16ssm隐状态的维度
ssm_ratio2.0d_inner = d_state * ssm_ratio
ssm_initv0ssm初始化版本
forward_typev2ssm前向版本

模型参数初始化

大部分参数即SS2D,VSS块中的参数由定义的ssm初始化版本初始化,剩下的线性层和归一化层参数由下面的函数初始化。

    def _init_weights(self, m: nn.Module):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

模型搭建

def_make_layer

构建VSSM的4个阶段,即4层,VSSBlock本身并不改变输入的尺寸,因此需要下采样模块将输出维度变换为下一阶段的输入维度

def _make_layer(
        dim=96, 
        drop_path=[0.1, 0.1], 
        use_checkpoint=False, 
        norm_layer=nn.LayerNorm,
        downsample=nn.Identity(),
        # ===========================
        ssm_d_state=16,
        ssm_ratio=2.0,
        ssm_dt_rank="auto",       
        ssm_act_layer=nn.SiLU,
        ssm_conv=3,
        ssm_conv_bias=True,
        ssm_drop_rate=0.0, 
        ssm_init="v0",
        forward_type="v2",
        # ===========================
        mlp_ratio=4.0,
        mlp_act_layer=nn.GELU,
        mlp_drop_rate=0.0,
        **kwargs,
    ):
        depth = len(drop_path)
        blocks = []
        for d in range(depth):
            blocks.append(VSSBlock(
                hidden_dim=dim, 
                drop_path=drop_path[d],
                norm_layer=norm_layer,
                ssm_d_state=ssm_d_state,
                ssm_ratio=ssm_ratio,
                ssm_dt_rank=ssm_dt_rank,
                ssm_act_layer=ssm_act_layer,
                ssm_conv=ssm_conv,
                ssm_conv_bias=ssm_conv_bias,
                ssm_drop_rate=ssm_drop_rate,
                ssm_init=ssm_init,
                forward_type=forward_type,
                mlp_ratio=mlp_ratio,
                mlp_act_layer=mlp_act_layer,
                mlp_drop_rate=mlp_drop_rate,
                use_checkpoint=use_checkpoint,
            ))
        
        return nn.Sequential(OrderedDict(
            blocks=nn.Sequential(*blocks,),
            downsample=downsample,
        ))
def _make_downsample

默认下采样版本v2

下采样模块,卷积核大小为2,步长为2,长宽变为原来的一半,维度变化由输入维度dim和输出维度out_dim定义。从而通过卷积变换为下一层的输出维度。

    def _make_downsample(dim=96, out_dim=192, norm_layer=nn.LayerNorm):
        return nn.Sequential(
            Permute(0, 3, 1, 2),
            nn.Conv2d(dim, out_dim, kernel_size=2, stride=2),
            Permute(0, 2, 3, 1),
            norm_layer(out_dim),
        )

patch embed

默认嵌入版本v1,对输入图像进行embed

输入x维度 [ B , 3 , H , W ] [B, 3, H, W] [B,3,H,W],嵌入后通道维变为96, H = H p a t c h _ s i z e H = \frac{H}{patch\_size} H=patch_sizeH W = W p a t c h _ s i z e W = \frac{W}{patch\_size} W=patch_sizeW [ B , 96 , H 4 , W 4 ] [B, 96, \frac{H}{4}, \frac{W}{4}] [B,96,4H,4W]

 def _make_patch_embed(in_chans=3, embed_dim=96, patch_size=4, patch_norm=True, norm_layer=nn.LayerNorm):
        return nn.Sequential(
            nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True),
            Permute(0, 2, 3, 1),
            (norm_layer(embed_dim) if patch_norm else nn.Identity()), 
        )

第一至四阶段

这几个阶段的差别在于每一层的VSSBlock数不同,由depths定义分别为 [2, 2, 9, 2],输出维度由dims定义分别为[96, 192, 384, 768]。其组成元素除一阶段外,均在VSSBlock前包含下采样模块以变换维度。

具体介绍见VSSBlock

分类器

池化后长宽变为1,则变量尺寸变为 [ B , C , 1 , 1 ] [B, C, 1, 1] [B,C,1,1],展平后变为 [ B , C ] [B, C] [B,C]最后线性投影到类别维度1000

[ B , 1000 ] [B, 1000] [B,1000]

self.classifier = nn.Sequential(OrderedDict(
            norm=norm_layer(self.num_features), # B,H,W,C
            permute=Permute(0, 3, 1, 2),
            avgpool=nn.AdaptiveAvgPool2d(1),
            flatten=nn.Flatten(1),
            head=nn.Linear(self.num_features, num_classes),
        ))

VSSBlock

对于ssm分支来说,其输入输出维度不变为(B, H, W, d_model) ,对于mlp分支来说中间的隐藏维度根据mlp_ratio参数定义会有所增加,但是最后又会映射为原来的维度,因此整体上并不改变输入的维度。

def __ init__

主要分为两个分支ssm分支和mlp分支

ssm分支

主要组成部分是SS2D块
SS2D实现

if self.ssm_branch:
            self.norm = norm_layer(hidden_dim)
            self.op = _SS2D(
                d_model=hidden_dim, 
                d_state=ssm_d_state, 
                ssm_ratio=ssm_ratio,
                dt_rank=ssm_dt_rank,
                act_layer=ssm_act_layer,
                # ==========================
                d_conv=ssm_conv,
                conv_bias=ssm_conv_bias,
                # ==========================
                dropout=ssm_drop_rate,
                # =========================
                initialize=ssm_init,
                forward_type=forward_type,
            )      

图中的SS2D和SS2D类的定义有偏差,简单来说是是包含SS2D块加一个残差连接,图中所示SS2D应表示状态空间模型SSM部分,即VSS块相比SS2D块只增加了残差连接和入口的归一化。如果定义了MLP分支,VSS块的输出还会经过一个残差连接的两层MLP

在这里插入图片描述

mlp分支
 if self.mlp_branch:
            self.norm2 = norm_layer(hidden_dim)
            mlp_hidden_dim = int(hidden_dim * mlp_ratio)
            self.mlp = Mlp(in_features=hidden_dim, hidden_features=mlp_hidden_dim, act_layer=mlp_act_layer, drop=mlp_drop_rate, channels_first=False)
            
 class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,channels_first=False):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        Linear = partial(nn.Conv2d, kernel_size=1, padding=0) if channels_first else nn.Linear
        self.fc1 = Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

def forward

    def _forward(self, input: torch.Tensor):
        if self.ssm_branch:
            if self.post_norm:
                x = input + self.drop_path(self.norm(self.op(input)))
            else:
                x = input + self.drop_path(self.op(self.norm(input)))
        if self.mlp_branch:
            if self.post_norm:
                x = x + self.drop_path(self.norm2(self.mlp(x))) # FFN
            else:
                x = x + self.drop_path(self.mlp(self.norm2(x))) # FFN
        return x

     
    
    def forward(self, input: torch.Tensor):
        if self.use_checkpoint:
            return checkpoint.checkpoint(self._forward, input)
        else:
            return self._forward(input)

  • 29
    点赞
  • 67
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值