I. Installing causal-conv1d
1. First, install packaging
conda install packaging
2. Clone the repository and check out v1.1.1
git clone https://github.com/Dao-AILab/causal-conv1d.git
cd causal-conv1d
git checkout v1.1.1
3. Build and install causal-conv1d
CAUSAL_CONV1D_FORCE_BUILD=TRUE pip install .
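The VAR=value prefix in step 3 is bash syntax. In cmd the equivalent is set CAUSAL_CONV1D_FORCE_BUILD=TRUE followed by pip install . on the next line, and in PowerShell it is $env:CAUSAL_CONV1D_FORCE_BUILD="TRUE" followed by pip install . If you prefer something that behaves the same in every shell, a small Python launcher can pass the variable through the child process environment. This is a minimal sketch; build_install_command and run_install are hypothetical helper names of my own, not part of causal-conv1d:

```python
import os
import subprocess
import sys


def build_install_command(env_var, value="TRUE"):
    """Return (argv, env) for running `pip install .` with one extra env var set."""
    env = os.environ.copy()   # inherit PATH, CUDA_HOME and friends from the current shell
    env[env_var] = value      # e.g. CAUSAL_CONV1D_FORCE_BUILD=TRUE, visible only to the child
    # Use the pip that belongs to the running interpreter (the active conda env).
    argv = [sys.executable, "-m", "pip", "install", "."]
    return argv, env


def run_install(env_var="CAUSAL_CONV1D_FORCE_BUILD"):
    """Run pip install . in the current directory with env_var set to TRUE."""
    argv, env = build_install_command(env_var)
    subprocess.run(argv, env=env, check=True)  # raises CalledProcessError if the build fails
```

Call run_install() from a Python session started inside the causal-conv1d checkout. Passing env= to subprocess.run avoids any shell-specific quoting rules.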
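On Windows, however, the command in step 3 often does not work (for example, cmd and PowerShell do not accept the VAR=value prefix). In that case, edit the following code at around line 37 of setup.py: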
FORCE_BUILD = os.getenv("CAUSAL_CONV1D_FORCE_BUILD", "FALSE").upper() == "TRUE"
SKIP_CUDA_BUILD = os.getenv("CAUSAL_CONV1D_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
FORCE_CXX11_ABI = os.getenv("CAUSAL_CONV1D_FORCE_CXX11_ABI", "FALSE") == "TRUE"
Change it to the following (only the default value of FORCE_BUILD changes, from "FALSE" to "TRUE"):
FORCE_BUILD = os.getenv("CAUSAL_CONV1D_FORCE_BUILD", "TRUE").upper() == "TRUE"
SKIP_CUDA_BUILD = os.getenv("CAUSAL_CONV1D_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
FORCE_CXX11_ABI = os.getenv("CAUSAL_CONV1D_FORCE_CXX11_ABI", "FALSE") == "TRUE"
Then run pip install . again and the installation should go through.
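To see what the setup.py edit changes, the flag expressions can be evaluated against a plain dict instead of the real environment (os.getenv(name, default) behaves like os.environ.get(name, default), so a dict lookup mirrors it). A sketch; the function names are mine, and my understanding that FORCE_BUILD makes setup.py compile from source rather than fetch a prebuilt wheel is an assumption worth checking against your copy of setup.py:

```python
from typing import Mapping


def original_force_build(env: Mapping[str, str]) -> bool:
    # Unpatched line: an unset variable falls back to "FALSE", so no forced build.
    return env.get("CAUSAL_CONV1D_FORCE_BUILD", "FALSE").upper() == "TRUE"


def patched_force_build(env: Mapping[str, str]) -> bool:
    # Patched line: an unset variable falls back to "TRUE", so the build is forced.
    return env.get("CAUSAL_CONV1D_FORCE_BUILD", "TRUE").upper() == "TRUE"


def skip_cuda_build(env: Mapping[str, str]) -> bool:
    # No .upper() here: only the exact string "TRUE" skips the CUDA build,
    # so "False", "FALSE" and an unset variable all give False.
    return env.get("CAUSAL_CONV1D_SKIP_CUDA_BUILD", "FALSE") == "TRUE"


for env in ({}, {"CAUSAL_CONV1D_FORCE_BUILD": "true"}):
    print(env, original_force_build(env), patched_force_build(env))
```

The practical effect is that after the edit you no longer need to set any environment variable at all.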
II. Installing mamba-ssm
1. Clone the repository
git clone https://github.com/state-spaces/mamba.git
cd mamba
git checkout v1.1.1
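2. Edit the build flags in the mamba source's setup.py. The original lines compare against "TRUE"; change the comparisons to "FALSE" as below, so that when the environment variables are not set both FORCE_BUILD and SKIP_CUDA_BUILD evaluate to True: the package is built from source and the selective_scan_cuda CUDA extension is not compiled (step 3 replaces the code that depends on it).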
FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "FALSE"
SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "FALSE"
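Note that this edit inverts the usual meaning of the variables. A quick sketch with a dict standing in for the environment (mamba_flags is a hypothetical helper, not part of mamba):

```python
from typing import Mapping, Tuple


def mamba_flags(env: Mapping[str, str]) -> Tuple[bool, bool]:
    """Evaluate the two edited mamba setup.py expressions against env."""
    # Edited lines compare with "FALSE", so an unset variable yields True.
    force_build = env.get("MAMBA_FORCE_BUILD", "FALSE") == "FALSE"
    skip_cuda_build = env.get("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "FALSE"
    return force_build, skip_cuda_build


print(mamba_flags({}))
print(mamba_flags({"MAMBA_SKIP_CUDA_BUILD": "TRUE"}))
```

With nothing set, both flags are True, which is what the Windows route wants. The catch: explicitly setting MAMBA_SKIP_CUDA_BUILD=TRUE now turns skipping off, so leave these variables unset after the edit.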
3. Edit mamba_ssm/ops/selective_scan_interface.py
Change the following code
def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                      return_last_state=False):
    """if return_last_state is True, returns (out, last_state)
    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
    not considered in the backward pass.
    """
    return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)


def mamba_inner_fn(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, out_proj_bias,
    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True
):
    return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                              out_proj_weight, out_proj_bias,
                              A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
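to the following code, which calls the pure-PyTorch reference implementations selective_scan_ref and mamba_inner_ref (defined in the same file) instead of the CUDA-backed SelectiveScanFn and MambaInnerFn. This avoids the selective_scan_cuda extension, at the cost of slower execution: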
def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                      return_last_state=False):
    """if return_last_state is True, returns (out, last_state)
    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
    not considered in the backward pass.
    """
    return selective_scan_ref(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)


def mamba_inner_fn(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, out_proj_bias,
    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True
):
    return mamba_inner_ref(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                           out_proj_weight, out_proj_bias,
                           A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
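Why the reference path is slower: selective_scan_ref evaluates the state-space recurrence with a Python loop over the sequence length, while the CUDA kernel fuses that loop on the GPU. The sketch below is my own simplification for intuition, not the library code: one channel, one batch element, real-valued A, and no z or delta_bias. It follows the same update, h = exp(delta * A) * h + delta * B * u, then y = C . h + D * u:

```python
import math


def selective_scan_1ch(u, delta, A, B, C, D=None, delta_softplus=False):
    """Sequential selective scan for a single channel (simplified sketch).

    u, delta: length-L sequences; A: length-N state coefficients;
    B, C: L rows of N values (input-dependent, one row per time step);
    D: optional scalar skip weight.
    """
    n_state = len(A)
    h = [0.0] * n_state  # hidden state, starts at zero
    ys = []
    for t, (u_t, d_t) in enumerate(zip(u, delta)):
        if delta_softplus:
            d_t = math.log1p(math.exp(d_t))  # softplus; fine for moderate d_t
        for n in range(n_state):
            # Discretised update: decay the old state, add the new input.
            h[n] = math.exp(d_t * A[n]) * h[n] + d_t * B[t][n] * u_t
        y_t = sum(h[n] * C[t][n] for n in range(n_state))  # readout C . h
        if D is not None:
            y_t += D * u_t  # skip connection
        ys.append(y_t)
    return ys


print(selective_scan_1ch([1.0, 2.0, 3.0], [1.0] * 3, [0.0], [[1.0]] * 3, [[1.0]] * 3))
```

With A = 0 the state never decays, so the output is just a running sum of the inputs. Every time step depends on the previous one, which is why a step-by-step loop over long sequences is expensive.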
4. Then, from inside the mamba folder, install it:
pip install .
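After installing, a quick way to check which pieces ended up in the environment is importlib.util.find_spec, which locates a top-level module without importing it. The extension names causal_conv1d_cuda and selective_scan_cuda are the names as I recall them from the two projects' setup.py files; treat them as assumptions. With the Windows route above, I would expect mamba_ssm to be found and selective_scan_cuda to be missing, since its build was skipped:

```python
import importlib.util

# Assumed module names: two Python packages plus their compiled CUDA extensions.
MODULES = ["causal_conv1d", "causal_conv1d_cuda", "mamba_ssm", "selective_scan_cuda"]


def check_modules(names=MODULES):
    """Map each top-level module name to whether Python can locate it."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, found in check_modules().items():
        print(f"{name:22s} {'found' if found else 'missing'}")
```

"found" only means the file exists; an extension can still fail at import time (for example because of a missing DLL), so follow up with an actual import mamba_ssm.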
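III. The steps above are for installing mamba-ssm on Windows. On Linux it is much easier: go to the releases page below, find the .whl file for the version you need, download it, and install it with pip install <downloaded file>.whl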
https://github.com/state-spaces/mamba/releases
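Picking the right wheel means matching your local setup. As far as I can tell from mamba's setup.py, the release wheel names encode the CUDA version, the PyTorch version, the C++11 ABI flag and the CPython tag (something like cu118torch2.1cxx11abiFALSE-cp310). The sketch below collects those values; wheel_hints is a hypothetical helper, and torch._C._GLIBCXX_USE_CXX11_ABI is a private PyTorch attribute that could change between releases:

```python
import platform
import sys


def wheel_hints():
    """Collect the local values that the release wheel names appear to be keyed on."""
    info = {
        "python_tag": f"cp{sys.version_info.major}{sys.version_info.minor}",
        "platform": platform.system(),
    }
    try:
        import torch
    except ImportError:
        info["torch"] = None  # install PyTorch first; the wheels are built against it
        return info
    info["torch"] = torch.__version__
    info["cuda"] = torch.version.cuda  # None on CPU-only PyTorch builds
    info["cxx11abi"] = torch._C._GLIBCXX_USE_CXX11_ABI  # private attribute
    return info


if __name__ == "__main__":
    for key, value in wheel_hints().items():
        print(f"{key}: {value}")
```

Compare the printed values with the wheel filenames on the releases page and download the one that matches all of them.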