Installing Mamba on Windows

This post walks through installing and configuring the Mamba framework on Windows: downloading and installing CUDA 11.8, then installing the required libraries (Torch, Causal-Conv1d, and mamba-ssm) in a virtual environment, with emphasis on the hardware-aware algorithm and on matching specific library versions.

1. Advantages of Mamba

  • Selection mechanism: introduces input-dependent gating, similar to the gating in RNNs (see the sketch after this list)
  • Linear in sequence length: compute scales linearly, reducing both training and inference cost
  • Hardware-aware algorithm: reduces hardware overhead
  • The architecture combines the H3 block with a gated MLP
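
For intuition, below is a minimal, unoptimized sketch of the selective-scan recurrence behind the selection mechanism. The tensor shapes follow the conventions of the reference code (u, delta: (batch, dim, seqlen); A: (dim, dstate); B, C: (batch, dstate, seqlen); D: (dim,)); the Python loop stands in for the fused hardware-aware CUDA kernel that the real library uses.

import torch

def selective_scan_sketch(u, delta, A, B, C, D):
    # Discretize A and B with the input-dependent step size delta ("selection")
    deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A))    # (b, d, l, n)
    deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u)     # (b, d, l, n)
    batch, dim, seqlen = u.shape
    x = u.new_zeros(batch, dim, A.shape[1])                       # hidden state
    ys = []
    for t in range(seqlen):                                       # linear in seqlen
        x = deltaA[:, :, t] * x + deltaB_u[:, :, t]               # x_t = A_bar_t * x_{t-1} + B_bar_t * u_t
        ys.append(torch.einsum("bdn,bn->bd", x, C[:, :, t]))      # y_t = C_t . x_t
    y = torch.stack(ys, dim=-1)                                   # (b, d, l)
    return y + u * D.unsqueeze(-1)                                # skip connection via D

# Example with random inputs on CPU
b, d, l, n = 2, 4, 8, 16
y = selective_scan_sketch(torch.randn(b, d, l), torch.rand(b, d, l),
                          -torch.rand(d, n), torch.randn(b, n, l),
                          torch.randn(b, n, l), torch.randn(d))
print(y.shape)  # torch.Size([2, 4, 8])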

2. Installation

2.1 Install CUDA

Download CUDA 11.8 directly from the NVIDIA website:
cu11.8:https://developer.nvidia.com/cuda-11-8-0-download-archive?target_os=Windows

It is best to download the local (offline) installer; the network exe installer may require a proxy. After installation, check the PATH environment variable to confirm the CUDA 11.8 entries are present. Note that the CUDA version reported by nvidia-smi only reflects the highest version supported by the driver; run nvcc -V to confirm that the installed toolkit is actually 11.8.
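
A quick sanity check from Python that the CUDA 11.8 compiler is visible on PATH (a minimal sketch; running nvcc -V in a terminal works just as well):

import shutil
import subprocess

nvcc = shutil.which("nvcc")   # None if CUDA's bin directory is not on PATH
print("nvcc found at:", nvcc)
if nvcc:
    # The version banner should report "release 11.8"
    print(subprocess.run([nvcc, "-V"], capture_output=True, text=True).stdout)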

2.2 Base environment

Create a virtual environment and install cudatoolkit==11.8 so that the environment carries its own CUDA version (convenient). The cuda-nvcc install at the end is required; without it the environment may not be able to find a matching CUDA compiler.

conda create -n your_env_name python=3.10.13
conda activate your_env_name
conda install cudatoolkit==11.8 -c nvidia
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc
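
Once the packages above are installed, a quick check confirms that the CUDA 11.8 build of PyTorch is active in the environment:

import torch

print(torch.__version__)           # expected: 2.1.1+cu118
print(torch.version.cuda)          # expected: 11.8
print(torch.cuda.is_available())   # expected: True on a machine with an NVIDIA GPU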

2.3 Install Causal-Conv1d

When installing Causal-Conv1d, be sure to check out the tag that matches your CUDA version.

First, install packaging:

conda install packaging

Then try installing directly (this did not work in my testing):

pip install causal-conv1d==1.1.1

Installing from source
First install triton (which bundles cmake) to provide the toolchain for the build steps that follow. A prebuilt Windows wheel for triton 2.0.0 (Python 3.10) is linked below; download it and install it with pip install <path to the .whl>:
triton-2.0.0-cp310: https://hf-mirror.com/r4ziel/xformers_pre_built/blob/main/triton-2.0.0-cp310-cp310-win_amd64.whl

Next, clone the repository from GitHub:

git clone https://github.com/Dao-AILab/causal-conv1d.git
cd causal-conv1d
git checkout v1.1.1

Then install causal-conv1d itself.

Run the command below from Git Bash (download Git first if you don't have it); the CAUSAL_CONV1D_FORCE_BUILD=TRUE prefix uses POSIX shell syntax and is not recognized in Windows Terminal (cmd/PowerShell):

CAUSAL_CONV1D_FORCE_BUILD=TRUE pip install .
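
After the build completes, a short smoke test confirms that the extension imports and runs. The shapes below follow the causal_conv1d_fn interface as I understand it for v1.1.1 (x: (batch, dim, seqlen), weight: (dim, width)); treat the exact call signature as an assumption and adjust if your version differs:

import torch
from causal_conv1d import causal_conv1d_fn

x = torch.randn(2, 16, 64, device="cuda", dtype=torch.float16)    # (batch, dim, seqlen)
weight = torch.randn(16, 4, device="cuda", dtype=torch.float16)   # (dim, width)
bias = torch.randn(16, device="cuda", dtype=torch.float16)        # (dim,)
y = causal_conv1d_fn(x, weight, bias, activation="silu")
print(y.shape)   # expected: torch.Size([2, 16, 64])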

2.4 Install mamba-ssm

Download the source from GitHub for a local install:

git clone https://github.com/state-spaces/mamba.git
cd mamba
git checkout v1.1.1

With the mamba source cloned, a few modifications are needed before installing:

  • Modify the configuration in mamba's setup.py so that FORCE_BUILD and SKIP_CUDA_BUILD both default to True (build the package locally and skip compiling the selective_scan_cuda extension):
FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "FALSE"
SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "FALSE"
  • In mamba_ssm/ops/selective_scan_interface.py, comment out the import below (the CUDA extension is no longer built):
import selective_scan_cuda

  • In the same file (mamba_ssm/ops/selective_scan_interface.py), modify the function bodies:
    replace the following code
def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                     return_last_state=False):
    """if return_last_state is True, returns (out, last_state)
    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
    not considered in the backward pass.
    """
    return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
 
 
def mamba_inner_fn(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, out_proj_bias,
    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True
):
    return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                              out_proj_weight, out_proj_bias,
                              A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)

with the following, which routes the calls to the pure-PyTorch reference implementations (selective_scan_ref and mamba_inner_ref) instead of the compiled CUDA kernels:

def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                     return_last_state=False):
    """if return_last_state is True, returns (out, last_state)
    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
    not considered in the backward pass.
    """
    return selective_scan_ref(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
 
def mamba_inner_fn(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, out_proj_bias,
    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True
):
    return mamba_inner_ref(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                              out_proj_weight, out_proj_bias,
                              A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
 

Then run the install from the mamba directory:

pip install .

For reference, pip list in a working environment then shows the following versions:
Package             Version
------------------- ------------
attrs               23.2.0
Automat             22.10.0
buildtools          1.0.6
causal-conv1d       1.1.1
certifi             2022.12.7
charset-normalizer  2.1.1
cmake               3.29.2
colorama            0.4.6
constantly          23.10.4
docopt              0.6.2
einops              0.7.0
filelock            3.9.0
fsspec              2024.3.1
furl                2.1.3
greenlet            3.0.3
huggingface-hub     0.22.2
hyperlink           21.0.0
idna                3.4
incremental         22.10.0
Jinja2              3.1.2
mamba-ssm           1.1.1
MarkupSafe          2.1.3
mpmath              1.3.0
networkx            3.2.1
ninja               1.11.1.1
numpy               1.26.3
orderedmultidict    1.0.1
packaging           23.2
pillow              10.2.0
pip                 23.3.1
python-dateutil     2.9.0.post0
PyYAML              6.0.1
redo                2.0.4
regex               2024.4.16
requests            2.28.1
safetensors         0.4.3
setuptools          68.2.2
simplejson          3.19.2
six                 1.16.0
SQLAlchemy          2.0.29
sympy               1.12
tokenizers          0.19.1
torch               2.1.1+cu118
torchaudio          2.1.1+cu118
torchvision         0.16.1+cu118
tqdm                4.66.2
transformers        4.40.0
triton              2.0.0
Twisted             24.3.0
twisted-iocpsupport 1.0.4
typing_extensions   4.8.0
urllib3             1.26.13
wheel               0.41.2
zope.interface      6.3
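
Finally, a forward pass through a single Mamba block (adapted from the usage example in the mamba-ssm README) verifies the whole installation end to end. With SKIP_CUDA_BUILD enabled and the interface patched as above, this exercises the pure-PyTorch reference path rather than the fused CUDA kernels:

import torch
from mamba_ssm import Mamba

batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim, device="cuda")
model = Mamba(
    d_model=dim,   # model dimension
    d_state=16,    # SSM state dimension
    d_conv=4,      # local convolution width
    expand=2,      # block expansion factor
).to("cuda")
y = model(x)
print(y.shape)   # expected: torch.Size([2, 64, 16])
assert y.shape == x.shape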
