关于Umamba和mamba-SSM输入的小笔记 (2D情况)

本文介绍了如何在PyTorch中使用Umamba库实现MambaLayer,这是一种结合了层归一化、自注意力机制和局部卷积的模块,用于处理多通道特征图。MambaLayer在训练和推理阶段对(B,C,H,W)格式的输入进行操作,其中C表示特征通道数,H和W为图像尺寸。
摘要由CSDN通过智能技术生成

Umamba (https://arxiv.org/pdf/2401.04722.pdf) 里输入就用最正常的Pytorch格式就行,即 (B, C, H, W),

class MambaLayer(nn.Module):
    def __init__(self, dim, d_state = 16, d_conv = 4, expand = 2):
        super().__init__()
        self.dim = dim
        self.norm = nn.LayerNorm(dim)
        self.mamba = Mamba(
                d_model=dim, # Model dimension d_model
                d_state=d_state,  # SSM state expansion factor
                d_conv=d_conv,    # Local convolution width
                expand=expand,    # Block expansion factor
        )
    
    @autocast(enabled=False)
    def forward(self, x):
        if x.dtype == torch.float16:
            x = x.type(torch.float32)
        B, C = x.shape[:2]
        assert C == self.dim
        n_tokens = x.shape[2:].numel()
        img_dims = x.shape[2:]
        x_flat = x.reshape(B, C, n_tokens).transpose(-1, -2)
        x_norm = self.norm(x_flat)
        x_mamba = self.mamba(x_norm)
        out = x_mamba.transpose(-1, -2).reshape(B, C, *img_dims)

        return out

训练或推理时迭代的x, 是(B, C, H ,W), B 为batch size, C是channel dim, 比如浅层或许为1, 3, 深层512, H, W 为当前图像长宽。 

初始化MambaLayer时候, dim对应的是当前channel dim, 其实就是特征深度或者说厚度,

即 mamba_exmaple = MambaLayer(C).

我们再看看mamba本身库里, 

    def forward(self, hidden_states, inference_params=None):
        """
        hidden_states: (B, L, D)
        Returns: same shape as hidden_states
        """
        batch, seqlen, dim = hidden_states.shape

hidden_states是输入x, 在上面第一块代码里即为x_norm, x_norm的维度是 所谓(B,L, D), 这里有符号的不同, 这里的D其实就是上面C, 即为特征厚度,或说特征通道数。 L是HxW, 即把图像拉长成一维数据。

  • 3
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值