windows系统下anaconda中配置Mamba官方代码环境

#项目场景#

最近Mamba有关的论文引起了众多人的关注,笔者知道很多人开始复现官方代码,但是由于官方代码虚拟环境创建在ubuntu系统下,因此在windows系统下复现代码遇到各种各样的问题。笔者在经过一晚上的尝试,总算在win11系统anaconda中成功配置了环境。

问题描述 

  1. Building wheel for causal-conv1d (setup.py) ... error
  2. ERROR: Could not build wheels for causal-conv1d, which is required to install pyproject.toml-based projects
  3. ERROR: Could not build wheels for mamba-ssm, which is required to install pyproject.toml-based projects

主要原因是CUDA版本不兼容。

解决方案 

本次复现主要参考这篇博客方法三,并根据自己电脑情况修改了一些步骤。

conda create -n your_env_name python=3.10.13
conda activate your_env_name
conda install cudatoolkit==11.8 -c nvidia
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc
conda install packaging

接下来安装 'triton’包:tritan包安装教程,有大神编译了Windows下二进制文件,下载到本地后,在anacoda终端中,切换到tritan所在文件夹,输入

pip install triton-2.0.0-cp310-cp310-win_amd64.whl

然后使用源码编译安装causal-conv1d,注意在此之前,请检查torch的cuda版本和你自己的cuda版本是否一致。并输入

where nvcc

检查CUDA版本(最好在此电脑-属性-系统-高级系统设置-环境变量)中检查新安装的CUDA11.8是否添加(在这里卡了很久,也不知道是什么原因,where nvcc指令只能找到自己电脑的旧版本CUDA,笔者干脆手动安装了CUDA11.8和CUDNN安装教程),安装完成后,重启电脑,重新进入anaconda环境。接下来进行causal-conv1d安装,笔者采用源码编译的方式。首先请在causal-conv1d安装链接下载好对应版本的安装包(笔者下载的是1.0.0版本)。下载到本地后,解压,anaconda激活环境后进入该文件夹。输入

pip install .

在这里可能会出现

有时候缓存文件可能会导致安装出错。你可以尝试清理 pip 或 conda 的缓存

pip cache purge

然后再输入pip install .就可以啦。

接下来是mamba源码编译,请在mamba官方代码中setup.py文件修改配置

FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "FALSE"
SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "FALSE"
  • 此时,可以编译完成,但是无法将 selective_scan_cuda 包括进去,导入模块还是会出错。请在mamba_ssm/ops/selective_scan_interface.py该文件中注释掉:
    import selective_scan_cuda
    

def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                     return_last_state=False):
    """if return_last_state is True, returns (out, last_state)
    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
    not considered in the backward pass.
    """
    return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)


def mamba_inner_fn(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, out_proj_bias,
    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True
):
    return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                              out_proj_weight, out_proj_bias,
                              A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)

 改为

def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                     return_last_state=False):
    """if return_last_state is True, returns (out, last_state)
    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
    not considered in the backward pass.
    """
    return selective_scan_ref(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)

def mamba_inner_fn(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, out_proj_bias,
    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True
):
    return mamba_inner_ref(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                              out_proj_weight, out_proj_bias,
                              A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)

最后就可以正常做mamba框架的相关实验啦~

最后放一下Mamba模型的简要介绍:

Mamba模型:选择性状态空间的新架构

Mamba是一种新颖的序列建模架构,被誉为Transformer的潜在竞争对手。相比于传统的模型,Mamba引入了选择性状态空间的概念,以更高效和有效地捕获相关信息。以下是Mamba的关键特点:

  1. 线性时间复杂度:与Transformer不同,Mamba在序列长度方面以线性时间运行,适用于处理非常长的序列任务。
  2. 灵活性和效率:Mamba结合了传统状态空间模型和循环神经网络的优点,具有高效计算和灵活性。
  3. 硬件感知算法:Mamba使用一种硬件感知算法,通过扫描操作而不是卷积,在GPU上高效地执行计算。

Mamba的核心组成包括固定主干和输入相关转换。在训练期间,它类似于Transformer,同时处理整个序列;在推理中,它更符合传统的循环模型,提供有效的序列处理。

此外,Mamba还使用了SRAM(Static Random-Access Memory)来优化内存需求,使其成为处理长序列的有前途的模型。

笔者研一菜鸡一枚~在安装的过程中conda下载完cuda11.8后又手动下载一遍(真的没搞懂),希望评论区有大神知道其中原理的欢迎交流~

  • 35
    点赞
  • 83
    收藏
    觉得还不错? 一键收藏
  • 42
    评论
评论 42
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值