Mamba: Linear-Time Sequence Modeling with Selective State Spaces
1. 作者介绍
2. 以往的模型
3. Mamba中使用的结构
4. 文章整体框架
a. State Space Models(SSM)
b. Sequences with Structured State Spaces(S4)
S4比较SSM有三点升级
- 离散化SSM
- 循环/卷积表示
- 基于HiPPO处理长序列
ⅰ. 离散数据的连续化:基于零阶保持技术做连续化并采样
参考大佬博客:博客地址
- 首先,每次收到离散信号时,我们都会保留其值,直到收到新的离散信号,如此操作导致的结果就是创建了 SSM 可以使用的连续信号
- 保持该值的时间由一个新的可学习参数表示,称为步长(siz)——
- 有了连续的输入信号后,便可以生成连续的输出,并且仅根据输入的时间步长对值进行采样
这些采样值就是我们的离散输出,且可以针对A、B按如下方式做零阶保持(做了零阶保持的在对应变量上面加了个横杠)
最终使我们能够从连续 SSM 转变为离散SSM,使得不再是函数到函数x(t) → y(t),而是序列到序列xₖ → yₖ,所以你看到,矩阵
和
现在表示模型的离散参数,且这里使用
,而不是
来表示离散的时间步长
ⅱ. 循环结构表示:方便快速推理
在每个时间步,都会涉及到隐藏状态的更新(比如
取决于
和
的共同作用结果,然后通过
预测输出
)
然后可以这样展开(其中,
始终是
和
的共同作用之下更新的)
ⅲ. 卷积结构表示:方便并行训练
在经典的图像识别任务中,我们用过滤器(即卷积核kernels)来导出聚合特征,而SSM也可以表示成卷积的形式
由于mamba处理的是文本而不是图像,因此我们需要一维视角
而用来表示这个“过滤器”的内核源自 SSM 公式
ⅳ. 长距离依赖问题的解决之道——HiPPO
HiPPO: Recurrent Memory with Optimal Polynomial Projections
c. Mamba的Motivation
Parallel Scan
d. Mamba的三大创新(有选择处理信息 + 硬件感知算法 + 更简单的SSM架构)
ⅰ. GPU memory hierarchy
ⅱ. Kernal fusion
ⅲ. Recomputation
e. Mamba的结构
vmamba(mamba)环境配置以及踩坑
环境配置现已更新至mamba2中
请查看:Mamba2 coming back-Transformers are SSMs-CSDN博客
注意:
- mamba暂时只支持linux系统(windows的ssm的包还没发,时间2024.4.22)
- mamba配置和vmamba之差一个包 ss2d的包(底下介绍vmamba安装的同时会介绍mamba)
- 安装时,需合理的科学上网
- 如想快速安装请下滑至最底下寻找完整步骤
vmamba仓库的地址:GitHub - MzeroMiko/VMamba: VMamba: Visual State Space Models,code is based on mamba
mamba仓库的地址:GitHub - state-spaces/mamba
步骤
1.克隆仓库
git clone GitHub - MzeroMiko/VMamba: VMamba: Visual State Space Models,code is based on mamba
cd VMamba
2.创建环境
conda create -n vmamba python=3.8
conda activate vmamba
3.安装依赖
pip install -r requirements.txt
cd kernels/selective_scan && pip install .#直接安装这一步会报错先装完第4步再安装
4.安装cuda(重要‼️不安装会报错)
我这里安装在环境中的
cuda版本为12.1(如果不知道自己环境中的cuda版本,可以conda list下查看nvidia-nvtx-cu12)
torch版本为2.2
python版本为3.8
最大的坑在于需要在外面也安装cuda版本需要与vmamba环境中的版本一致,因为在环境中装的cuda是不完整的(如果想要查看是否装有cuda,请输入指令nvcc -V),
- 安装前需要查看自己的显卡驱动(输入nvidia-smi),也可以去nvidia官网查看安装cuda所需要的最低驱动,如果低于cuda要求的最低版本参考博客
- 安装cuda上nvidia官网(需科学上网)
- 我这里安装的是cuda12.1
wget https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run
sudo sh cuda_12.1.0_530.30.02_linux.run
- 最后需要更新cuda的路径
vim ~/.bashrc
#插入
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-12.1/lib64
export PATH=$PATH:/usr/local/cuda-12.1/bin
export CUDA_HOME=$CUDA_HOME:/usr/local/cuda-12.1
#保存退出
:wq!
#重启一下
source ~/.bashrc
5.安装到这一步,恭喜你跳过了一个大坑🎉🎉 安装ssm所需要的cuda环境,接下来就要安装
下面给出官方网址:Releases · state-spaces/mamba · GitHub
因为直接pip install mamba-ssm 会一直卡在setup那一步一直不动(因为这里的pip会直接调用github上面给的仓库,没法直接访问就一直卡着)所以就需要自己去上面那个github上找whl
对于vmamba而言:只需要安装mamba-ssmv1.2.0.post1这一个安装包就行了,所以这里给出我安装的whl所对应的网址:链接
对于mamba而言:需要安装mamba-ssm和causal-conv1d
完整的流程
conda create -n vmamba python=3.10
conda activate vmamba
conda install cudatoolkit==11.8 -c nvidia #这里使用11.8版本的cuda
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc
conda install packaging
下载完whl文件
pip install mamba_ssm-1.1.1+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
pip install causal_conv1d-1.1.3+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
参考:博客地址