(简单易学)将mamba2添加到你的模型(NLP | CV-2d)中【PyTorch】

与PS:本文方法不依赖包的版本环境,也不需要编译,即插即用,简单易学

1. mamba2模块简介

1.1. mamba2纯torch实现代码

参考git大牛的mamba2-minimal实现https://github.com/tommyip/mamba2-minimal/blob/main/mamba2.py

PS:当你的Python版本低于3.10的时候上述代码可能会有报错,这些报错都是python的类型定义说明,可以参考其它博客进行调整,或者直接删除报错位置的类型定义即可,只保留赋值前后的有效值。

关于更便捷的多维度双向mamba2的实现,可以参考我的这篇博客

BiMamba2的纯PyTorch实现的多维实现。任意模态(signal | nlp | cv | audio| vedio)的快速Mamba2缝合神器(支持1d,2d,3d,...Nd)-CSDN博客

1.2. 有效模块

需要用到的类Mamba2 ,Mamba2Config

2. 使用说明

2.1. 使用样例(定义)

mamba2_block = Mamba2(Mamba2Config(d_model=768),device= 'cuda:0')

2.2. 参数介绍(定义)

你可以调整的参数d_model,device

d_model 对应特征向量的大小,必须是64的倍数(不懂的看调用部分的参数介绍)

device需要在模块创建的时候提前指定,不支持定义后调整

2.3. 使用样例(调用)

x = torch.randn(2, 64, 768)  # (batch, seqlen, d_model)
y,h = mamba2_block(x)  # same shape as x

2.4. 参数介绍(调用)

在这里的分为输入参数和输出参数

输入参数

x:这里要求seqlen是句子长度(图像对应像素数量),d_model是特征向量大小(图像对应通道数),另外这里要求seqlen和d_model均可被64整除

输出参数

y:这里和x的形状是保持完全一致的

h:这是隐藏层的全部参数,包含了conv_state和ssd_state

3. NLP使用样例代码

import torch
import torch.nn.functional as F
from mamba2 import Mamba2, Mamba2Config


def _pad64(x):
    b, l, d = x.size()
    pad_len = (64 - l % 64) % 64
    pad_dim = (64 - d % 64) % 64
    return pad_len, pad_dim


def pad(x, pad_len, pad_dim):
    return F.pad(x, (0, pad_dim, 0, pad_len), mode='constant', value=0)


def unpad(x, pad_len, pad_dim):
    return x[:, :-pad_len, :-pad_dim]


# 假设原始的NLP信号的长度为L,特征维度为D,均不可被64整除
L = 999
D = 96
batch = 12

mamba2_block = Mamba2(Mamba2Config(d_model=(D+63)//64 * 64), device='cuda:0')

x = torch.randn(batch, L, D).cuda()
print("x.shape:", x.shape, ) # x.shape: torch.Size([12, 999, 96])


pad_info = _pad64(x)
_x = pad(x, *pad_info)
print("_x.shape:", _x.shape, )# _x.shape: torch.Size([12, 1024, 128])

_y, h = mamba2_block(_x)
print("_y.shape:", _y.shape, ) # _y.shape: torch.Size([12, 1024, 128])

y = unpad(_y, *pad_info)
print("y.shape:", y.shape) # y.shape: torch.Size([12, 999, 96])

运行结果

x.shape: torch.Size([12, 999, 96])
_x.shape: torch.Size([12, 1024, 128])
_y.shape: torch.Size([12, 1024, 128])
y.shape: torch.Size([12, 999, 96])

4. CV2D 使用样例代码

import torch
import torch.nn.functional as F
from mamba2 import Mamba2, Mamba2Config
from einops import rearrange


def _pad64(x):
    l, d = x.shape[-2:]
    pad_len = (64 - l % 64) % 64
    pad_dim = (64 - d % 64) % 64
    return pad_len, pad_dim

def _pad8(x):
    l, d = x.shape[-2:]
    pad_len = (8 - l % 8) % 8
    pad_dim = (8 - d % 8) % 8
    return pad_len, pad_dim

def pad(x, pad_len, pad_dim):
    return F.pad(x, (0, pad_dim, 0, pad_len), mode='constant', value=0)


def unpad(x, pad_len, pad_dim):
    return x[..., :-pad_len, :-pad_dim]


# 假设原始的CV2d数据的宽高为 H,W,通道数为C 均不可被64整除
C = 141
H = 37
W = 78
batch = 1

_C = (C + 63) // 64 * 64
_H = (H + 7) // 8 * 8  # 8 的 倍数
_W = (W + 7) // 8 * 8  # 8 的 倍数

conv_in = torch.nn.Conv2d(C, _C, 1, 1, 0).cuda()
mamba2_block = Mamba2(Mamba2Config(d_model=_C), device='cuda:0')
conv_out = torch.nn.Conv2d(_C, C, 1, 1, 0).cuda()

x = torch.randn(batch, C, H, W).cuda()
print("x.shape:", x.shape)
x = conv_in(x)
print("x.shape:", x.shape)
pad_info = _pad8(x)
_x = pad(x, *pad_info)
print("_x.shape:", _x.shape, )
_x = rearrange(_x, 'b c h w -> b (h w) c')
print("_x.shape:", _x.shape, )
_y, h = mamba2_block(_x)
print("_y.shape:", _y.shape, )
_y = rearrange(_y, 'b (h w) c -> b c h w', h=_H, w=_W)
print("_y.shape:", _y.shape, )
y = unpad(_y, *pad_info)
print("y.shape:", y.shape)
y = conv_out(y)
print("y.shape:", y.shape)

运行结果

x.shape: torch.Size([1, 141, 37, 78])
x.shape: torch.Size([1, 192, 37, 78])
_x.shape: torch.Size([1, 192, 40, 80])
_x.shape: torch.Size([1, 3200, 192])
_y.shape: torch.Size([1, 3200, 192])
_y.shape: torch.Size([1, 192, 40, 80])
y.shape: torch.Size([1, 192, 37, 78])
y.shape: torch.Size([1, 141, 37, 78])

特别说明

1、CV图片的像素数都比较大,若需要添加到大分辨率的特征层建议将输入和输出的Conv的卷积核和步长都调大,从而实现分块的效果。

2、分块之后的图像分辨率发生了变化,需要重新根据分块后的尺寸调用函数计算新的Pad相关参数

  • 12
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值