一、前言
如果下面的内容对您有帮助,请多多点赞、收藏、转发,谢谢大家的支持。同时博主创建了一个计算机视觉相关的科研互助群Q:950440005。另外,本人最近也对这些年的双模态检测工作做了总结,如果大家正在犯愁如何寻找双模态图像融合和目标检测相关的创新点,欢迎大家加群讨论👏👏👏
最近Mamba模型可谓是火遍全球,在23年12月份的时候,MambaV1正式发布。作者在论文中证明了该模型在推理和训练方面的优势。但和最开始的《Attention is All You Need》一样,对图像领域产生颠覆性影响的Transformer也是最先应用于文本序列领域。所以,为了拓展Mamba架构在图像领域的应用,今年也有许多新的工作,例如Vision Mamba和VMamba等。它们证实了Mamba架构在CV领域的巨大潜力。
在今年六月份的时候,Mamba的原作者继续公开了MambaV2模型。论文中从矩阵数学的角度深度剖析了SSM(State Space Model)与Self-Attention、Linear-Attention等注意力之间的深层关联。目前来看,MambaV2框架确有大一统之势(官方Github代码已有近 20k Starred!)。它或许就是下一个颠覆整个计算机视觉的新架构。所以,我们需要紧跟潮流,抓紧学习起来!
MambaV1论文:《Mamba: Linear-Time Sequence Modeling with Selective State》
MambaV2论文:《Transformers are SSMs: Generalized Models and Efficient Algorithms》
Mamba官方代码:https://github.com/state-spaces/mamba
二、令人痛苦的环境配置过程
我们在配置Mamba环境时,往往都是按照官方github的操作,直接使用 pip install causal-conv1d
和 pip install mamba-ssm
。但百分之九十九的情况下,大家都会遇到各种奇葩的Bug,例如下载网络请求超时、CUDA版本不匹配、C++编译失败等等。往往需要被折磨个一两天。那有没有什么办法能够尽量避免大部分的坑呢?当然有!只需要按照本博客一步一步来,直接跳过所有奇葩Bug!!!
三、解决方案
下载 causal-conv1d 和 mamba-ssm 这两个库的离线.whl文件,然后直接pip install 文件名.whl进行安装即可!
我配置的虚拟环境重要库版本是(建议大家将以下的各种库版本与我保持高度一致!):
- torch:2.3.1+cu118
- torchaudio:2.3.1+cu118
- torchvision:0.18.1+cu118
- triton:2.3.1
- transformers:4.43.3
- causal-conv1d:1.4.0
- mamba-ssm :2.2.2
- cuda-nvcc:1.8.89
(1)首先,我们需要先新建一个Conda虚拟环境:
conda create -n mamba python=3.10
(2)然后进入该环境:
conda activate mamba
(3)安装torch(建议2.3.1版本)以及相应的 torchvison、torchaudio:
方法一:官网指令安装(不推荐!)
conda install cudatoolkit==11.8 -c nvidia
pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu118
这种方法在国内下载速度巨慢,而且可能时间久了指令会失效,装了其他错乱版本,这是我们不想看到的。所以针对pytorch的安装,极力推荐下面的离线安装方式,需要对应哪个python、cuda版本的torch一目了然。
方法二:离线安装(推荐!!!!!!!!!!!!)
我们可以直接进入pytorch离线包下载网址,在里面寻找对应的pytorch以及torchvison、torchaudio。
为了方便大家下载,我已经将下面需要的torch以及mamba相关库都下载打包好了,百度云盘连接:
例如我们上面列出的配置列表对应的文件如下*(注意cu118为cuda版本、cp310为python版本)。
torch-2.3.1+cu118-cp310-cp310-linux_x86_64.whl
torchvision-0.18.1+cu118-cp310-cp310-linux_x86_64.whl
torchaudio-2.3.1+cu118-cp310-cp310-linux_x86_64.whl
下载完成后,进入这些文件的目录下,直接使用下面三个指令进行安装即可:
pip install torch-2.3.1+cu118-cp310-cp310-linux_x86_64.whl
pip install torchvision-0.18.1+cu118-cp310-cp310-linux_x86_64.whl
pip install torchaudio-2.3.1+cu118-cp310-cp310-linux_x86_64.whl
(4)安装完pytorch之后,接着安装triton和transformers库:
pip install triton==2.3.1
pip install transformers==4.43.3
(5)安装完这些我们最基本Pytorch环境以及配置完成,接下来就是Mamba所需的一些依赖了,由于Mamba需要底层的C++进行编译,所以还需要手动安装一下cuda-nvcc这个库,直接使用conda命令即可:
conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc
(6)最后就是下载最重要的 causal-conv1d 和mamba-ssm库。在这里我们同样选择离线安装的方式,来避免大量奇葩的编译bug。首先进入下面各自的github网址种进行下载对应版本(这两个库也打包在上面提供的百度网盘链接中):
和安装pytorch一样,进入下载的.whl文件所在文件夹,直接使用以下指令进行安装:
pip install causal_conv1d-1.4.0+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
pip install mamba_ssm-2.2.2+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
到此,所有环境配置已经完成!
四、史上最简单的环境配置方法——傻瓜式一键安装conda环境
如果你觉得上面介绍的方法还不够简单,可以参考这个博客:
同样我也将我的Conda环境进行了打包(包含了Yolov8),大家直接从下面百度网盘链接下载,然后按照博客内容进行安装即可!非常简单!!!
链接: https://pan.baidu.com/s/1pgbk4Pz5Bb6ijSMLJ8vQxg?pwd=bxf9 提取码: bxf9
–来自百度网盘超级会员v5的分享
五、MambaV2代码验证
安装好环境后,我们接着来验证一下Mamba块能否成功运行,直接复制下面代码保存问mamba2_test.py,并运行:
# Copyright (c) 2024, Tri Dao, Albert Gu.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
try:
from causal_conv1d import causal_conv1d_fn
except ImportError:
causal_conv1d_fn = None
try:
from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated, LayerNorm
except ImportError:
RMSNormGated, LayerNorm = None, None
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
class Mamba2Simple(nn.Module):
def __init__(
self,
d_model,
d_state=128,
d_conv=4,
conv_init=None,
expand=2,
headdim=64,
ngroups=1,
A_init_range=(1, 16),
dt_min=0.001,
dt_max=0.1,
dt_init_floor=1e-4,
dt_limit=(0.0, float("inf")),
learnable_init_states=False,
activation="swish",
bias=False,
conv_bias=True,
# Fused kernel and sharding options
chunk_size=256,
use_mem_eff_path=True,
layer_idx=None, # Absorb kwarg for general module
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.d_model = d_model
self.d_state = d_state
self.d_conv = d_conv
self.conv_init = conv_init
self.expand = expand
self.d_inner = self.expand * self.d_model
self.headdim = headdim
self.ngroups = ngroups
assert self.d_inner % self.headdim == 0
self.nheads = self.d_inner // self.headdim
self.dt_limit = dt_limit
self.learnable_init_states = learnable_init_states
self.activation = activation
self.chunk_size = chunk_size
self.use_mem_eff_path = use_mem_eff_path
self.layer_idx = layer_idx
# Order: [z, x, B, C, dt]
d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)
conv_dim = self.d_inner + 2 * self.ngroups * self.d_state
self.conv1d = nn.Conv1d(
in_channels=conv_dim,
out_channels=conv_dim,
bias=conv_bias,
kernel_size=d_conv,
groups=conv_dim,
padding=d_conv - 1,
**factory_kwargs,
)
if self.conv_init is not None:
nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
# self.conv1d.weight._no_weight_decay = True
if self.learnable_init_states:
self.init_states = nn.Parameter(torch.zeros(self.nheads, self.headdim, self.d_state, **factory_kwargs))
self.init_states._no_weight_decay = True
self.act = nn.SiLU()
# Initialize log dt bias
dt = torch.exp(
torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
)
dt = torch.clamp(dt, min=dt_init_floor)
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
self.dt_bias = nn.Parameter(inv_dt)
# Just to be explicit. Without this we already don't put wd on dt_bias because of the check
# name.endswith("bias") in param_grouping.py
self.dt_bias._no_weight_decay = True
# A parameter
assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range)
A_log = torch.log(A).to(dtype=dtype)
self.A_log = nn.Parameter(A_log)
# self.register_buffer("A_log", torch.zeros(self.nheads, dtype=torch.float32, device=device), persistent=True)
self.A_log._no_weight_decay = True
# D "skip" parameter
self.D = nn.Parameter(torch.ones(self.nheads, device=device))
self.D._no_weight_decay = True
# Extra normalization layer right before output projection
assert RMSNormGated is not None
self.norm = RMSNormGated(self.d_inner, eps=1e-5, norm_before_gate=False, **factory_kwargs)
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
def forward(self, u, seq_idx=None):
"""
u: (B, L, D)
Returns: same shape as u
"""
batch, seqlen, dim = u.shape
zxbcdt = self.in_proj(u) # (B, L, d_in_proj)
A = -torch.exp(self.A_log) # (nheads) or (d_inner, d_state)
initial_states=repeat(self.init_states, "... -> b ...", b=batch) if self.learnable_init_states else None
dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
if self.use_mem_eff_path:
# Fully fused path
out = mamba_split_conv1d_scan_combined(
zxbcdt,
rearrange(self.conv1d.weight, "d 1 w -> d w"),
self.conv1d.bias,
self.dt_bias,
A,
D=self.D,
chunk_size=self.chunk_size,
seq_idx=seq_idx,
activation=self.activation,
rmsnorm_weight=self.norm.weight,
rmsnorm_eps=self.norm.eps,
outproj_weight=self.out_proj.weight,
outproj_bias=self.out_proj.bias,
headdim=self.headdim,
ngroups=self.ngroups,
norm_before_gate=False,
initial_states=initial_states,
**dt_limit_kwargs,
)
else:
z, xBC, dt = torch.split(
zxbcdt, [self.d_inner, self.d_inner + 2 * self.ngroups * self.d_state, self.nheads], dim=-1
)
dt = F.softplus(dt + self.dt_bias) # (B, L, nheads)
assert self.activation in ["silu", "swish"]
# 1D Convolution
if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
xBC = self.act(
self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)
) # (B, L, self.d_inner + 2 * ngroups * d_state)
xBC = xBC[:, :seqlen, :]
else:
xBC = causal_conv1d_fn(
x=xBC.transpose(1, 2),
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
bias=self.conv1d.bias,
activation=self.activation,
).transpose(1, 2)
# Split into 3 main branches: X, B, C
# These correspond to V, K, Q respectively in the SSM/attention duality
x, B, C = torch.split(xBC, [self.d_inner, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
y = mamba_chunk_scan_combined(
rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
dt,
A,
rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
chunk_size=self.chunk_size,
D=self.D,
z=None,
seq_idx=seq_idx,
initial_states=initial_states,
**dt_limit_kwargs,
)
y = rearrange(y, "b l h p -> b l (h p)")
# Multiply "gate" branch and apply extra normalization layer
y = self.norm(y, z)
out = self.out_proj(y)
return out
if __name__ == '__main__':
model = Mamba2Simple(256).cuda()
inputs = torch.randn(2, 128, 256).cuda()
pred = model(inputs)
print(pred.size())
跑一次可能需要几分钟,耐心等待,最终能输出结果即配置成功!MambaV1也是类似的方法。
六、可能遇到的报错
报错1: assert RMSNormGated is not None AssertionErrorl
这是可能是因为transformers库版本过低,导致RMSNormGated 无法导入,升级成4.43.3版本即可!
报错2:RuntimeError: causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8
在代码中,要确保 d_model * expand / headdim = multiple of 8