Causal-Conv1d 和 Mamba-ssm 安装方法

一.安装Causal-Conv1d

1.首先安装packaging

conda install packaging

2.安装

git clone https://github.com/Dao-AILab/causal-conv1d.git
cd causal-conv1d
git checkout v1.1.1

3.进行install causal-conv1d

CAUSAL_CONV1D_FORCE_BUILD=TRUE pip install .

 但有时候行不通,我们需要在setup.py里 大约37行的以下代码

FORCE_BUILD = os.getenv("CAUSAL_CONV1D_FORCE_BUILD", "FALSE").upper() == "TRUE"
SKIP_CUDA_BUILD = os.getenv("CAUSAL_CONV1D_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
FORCE_CXX11_ABI = os.getenv("CAUSAL_CONV1D_FORCE_CXX11_ABI", "FALSE") == "TRUE"

 改为

FORCE_BUILD = os.getenv("CAUSAL_CONV1D_FORCE_BUILD", "TRUE").upper() == "TRUE"
SKIP_CUDA_BUILD = os.getenv("CAUSAL_CONV1D_SKIP_CUDA_BUILD", "False") == "TRUE"
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
FORCE_CXX11_ABI = os.getenv("CAUSAL_CONV1D_FORCE_CXX11_ABI", "False") == "TRUE"

 然后,再用 pip install . 安装就可以了

二、安装mamba-ssm

1.克隆

git clone https://github.com/state-spaces/mamba.git
cd mamba
git checkout v1.1.1

2.在mamba源码的 setup.py 修改配置

FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "FALSE"
SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "FALSE"

 3.在mamba_ssm/ops/selective_scan_interface.py 进行修改

将以下代码

def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                     return_last_state=False):
    """if return_last_state is True, returns (out, last_state)
    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
    not considered in the backward pass.
    """
    return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
 
 
def mamba_inner_fn(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, out_proj_bias,
    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True
):
    return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                              out_proj_weight, out_proj_bias,
                              A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)

 改为以下代码

def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                     return_last_state=False):
    """if return_last_state is True, returns (out, last_state)
    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
    not considered in the backward pass.
    """
    return selective_scan_ref(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
 
def mamba_inner_fn(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, out_proj_bias,
    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True
):
    return mamba_inner_ref(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                              out_proj_weight, out_proj_bias,
                              A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
 

4.然后直接进入mamba文件夹里安装

pip install .

三、前面是windows下的安装mamba-ssm方法 ,如果是linux下就方便多了,到下边网址里找到你要的版本的whl的文件下载

https://github.com/state-spaces/mamba/releases

### Mamba与UNet的集成、安装及配置 #### 轻量级策略概述 Mamba是一种全新的轻量级策略,旨在通过优化网络结构来减少模型参数数量的同时保持甚至提升性能[^1]。它被设计用于改进UNet架构,在不显著增加计算成本的情况下增强了对全局信息的理解能力。 #### 安装依赖项 为了使用Mamba与UNet结合的功能,通常需要先安装必要的Python库其他工具包。以下是基于`mamba`命令行工具的一个典型安装过程: ```bash # 创建一个新的环境并激活 mamba create -n unet_mamba_env python=3.9 mamba activate unet_mamba_env # 安装PyTorch及相关依赖 mamba install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch # 如果有特定版本需求,则可以手动指定unet相关模块 pip install git+https://github.com/your-repo/unet-mamba.git@v0.1 # 假设存在这样的仓库标签 ``` 上述脚本中的最后一行为假设情况下的操作;实际项目可能有不同的Git地址或者分发方式,请参照具体文档说明。 #### 配置文件调整 当集成了Mamba到UNet之后,还需要修改默认配置以适应新的混合架构特性。一般而言,这涉及更改训练超参以及定义自定义层的方式。下面是一个简单的例子展示如何初始化带有Mamba特性的UNet实例: ```python from custom_unet import UNetWithMamba model = UNetWithMamba( input_channels=3, output_classes=2, backbone="resnet18", # 或其他预定义选项 mamba_enabled=True, # 启用Mamba组件开关 ) print(model) ``` 此代码片段展示了创建一个启用了Mamba功能的新版UNet对象的过程。注意这里的`custom_unet`应替换为你所使用的框架的实际导入路径。 #### 性能对比分析 研究表明,采用Mamba增强后的UNet相比原始版本能够实现更少的参数规模(约减少了116倍),但在某些任务上的表现仍然优于后者。这种效率增益对于资源受限设备尤其重要。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

纬领网络

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值