Mamba Environment Installation Pitfalls: Problems and Solutions (Windows Solved)


Project scenario:

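Mamba-related papers have recently drawn a lot of attention. Although the Mamba paper itself was rejected by ICLR 2024, derived models keep appearing, such as Vim and U-Mamba. While configuring their environments (requirements: PyTorch 1.12+ and CUDA 11.6+), I found that following the authors' installation instructions triggers a large number of bugs, mostly in causal-conv1d and mamba-ssm, and all caused by version incompatibilities. I am recording them here.
P.S. After in-depth discussions with readers, this post has been expanded substantially. If you run into new problems or find new solutions, please let me know; fixes that we verify together will be added to this post promptly.
For installation issues, to request the files mentioned below, or to discuss paper collaboration, add me on WeChat: 931744281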

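Problem description

Installing directly with pip, or downloading the project and running setup, produces errors including (but not limited to) the following: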

  1. Building wheel for causal-conv1d (setup.py) ... error
  2. error: command '/usr/bin/gcc' failed with exit code 1
  3. RuntimeError: Error compiling objects for extension
  4. ERROR: Could not build wheels for causal-conv1d, which is required to install pyproject.toml-based projects
  5. Connection timed out> [end of output]
  6. ModuleNotFoundError: No module named 'packaging'
  7. FileNotFoundError: [Errno 2] No such file or directory: '/usr/local/cuda/bin/nvcc'
  8. error: subprocess-exited-with-error

Cause analysis:

Most failures are caused by mismatched CUDA versions; some are caused by network problems.


Solutions (Linux):

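  1. Use a ready-made Docker environment shared by another user; see: "Fixing causal_conv1d and mamba_ssm installation failures -> use the Mamba base-environment Docker image directly"
    Docker Hub repository: https://hub.docker.com/repository/docker/kom4cr0/cuda11.7-pytorch1.13-mamba1.1.1/general
    Command: docker pull kom4cr0/cuda11.7-pytorch1.13-mamba1.1.1:1.1.1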

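  2. Download the project source and run setup directly. For details, see "Cannot install causal_conv1d and mamba_ssm directly with pip install when running a Mamba project" and "Reproducing U-Mamba". I still could not install it this way, but the original author and some people in the GitHub issues succeeded.
    Reference steps: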

    git clone https://github.com/Dao-AILab/causal-conv1d.git
    cd causal-conv1d
    git checkout v1.1.1 # latest version tag at the time of writing
    CAUSAL_CONV1D_FORCE_BUILD=TRUE pip install .
    cd ..
    git clone https://github.com/state-spaces/mamba.git
    cd mamba
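    git checkout v1.1.1 # latest version tag at the time of writing
    pip install . # Option 1: download and install a prebuilt whl; use either option, not both
    MAMBA_FORCE_BUILD=TRUE pip install . # Option 2: force a local build; this syntax is not recognized on Windows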
    
  3. Inspired by the blog post "flash-attention pitfalls: managing CUDA with conda", adjust the installation order: install CUDA first, including cuda-nvcc. The correct steps are:

    conda create -n your_env_name python=3.10.13
    conda activate your_env_name
    conda install cudatoolkit==11.8 -c nvidia
    pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
    conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc
    conda install packaging
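    pip install causal-conv1d==1.1.1  # pick the version that fits your setup, or omit it to install the latest
    pip install mamba-ssm==1.1.3.post1  # pick the version that fits your setup; in my tests 1.1 and 1.2 have incompatible functions; without a pin the latest is installed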
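Since most of the errors above come from a CUDA mismatch, it can help to confirm, before the last two pip steps, that the nvcc in the environment reports the same CUDA release that PyTorch was built with (torch.version.cuda). The sketch below assumes nvcc -V prints a line containing "release X.Y", as CUDA 11.x and 12.x toolkits do; the helper names are hypothetical, not part of any library.

```python
import re
import shutil
import subprocess
from typing import Optional


def parse_nvcc_release(nvcc_output: str) -> Optional[str]:
    """Return 'X.Y' from the 'release X.Y' part of `nvcc -V` output, or None."""
    match = re.search(r"release (\d+\.\d+)", nvcc_output)
    return match.group(1) if match else None


def major_minor(version: str) -> str:
    """'11.8' or '11.8.89' -> '11.8'."""
    return ".".join(version.split(".")[:2])


def check_cuda_consistency(torch_cuda: Optional[str], nvcc_output: str) -> str:
    """Compare torch.version.cuda with the release reported by nvcc."""
    if torch_cuda is None:
        return "PyTorch missing or CPU-only: install a CUDA build of torch first"
    nvcc_release = parse_nvcc_release(nvcc_output)
    if nvcc_release is None:
        return "nvcc release not found: is cuda-nvcc installed in this environment?"
    if major_minor(torch_cuda) != nvcc_release:
        return f"MISMATCH: torch built with CUDA {torch_cuda}, nvcc reports {nvcc_release}"
    return f"OK: CUDA {nvcc_release}"


if __name__ == "__main__":
    try:
        import torch  # only available inside the environment created above
        torch_cuda = torch.version.cuda
    except ImportError:
        torch_cuda = None
    nvcc = shutil.which("nvcc")
    nvcc_out = subprocess.run([nvcc, "-V"], capture_output=True, text=True).stdout if nvcc else ""
    print(check_cuda_consistency(torch_cuda, nvcc_out))
```

With the environment above, torch installed from the cu118 index should report CUDA 11.8, and the conda cuda-nvcc from nvidia/label/cuda-11.8.0 should report release 11.8.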
    

Solutions (Windows):

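  1. Use the Docker environment;
  2. Modify the code following the steps in Update 20240313. Note that this skips the core CUDA acceleration, so the code runs but is very slow;
  3. Follow Update 20240329 and use WSL on Windows, or a Linux virtual machine.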

Update 20240313

  1. If the second-to-last step of method 3 (pip install causal-conv1d) fails, build it from the project source instead.

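  2. On Windows, installing mamba-ssm differs after the third-to-last step of method 3 (conda install packaging): first install the triton package, then build causal-conv1d and mamba from source, and modify the source. In mamba's setup.py, change the two checks below from == "TRUE" to == "FALSE", so that a forced build that skips CUDA compilation becomes the default: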

    FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "FALSE"
    SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "FALSE"
    
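    • Alternatively, run set MAMBA_FORCE_BUILD=TRUE and set MAMBA_SKIP_CUDA_BUILD=TRUE (in cmd; in PowerShell use $env:MAMBA_FORCE_BUILD="TRUE" and so on) before the pip install . command. The Linux syntax MAMBA_FORCE_BUILD=TRUE pip install . raises an error on Windows.
    • The build then completes, but selective_scan_cuda is not included, so importing the module still fails.
    • Therefore, modify the source before building; see: Windows Support #12
    • That is, in ops/selective_scan_interface.py, comment out
    import selective_scan_cuda
    and modify the following two functions: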
    

    def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                         return_last_state=False):
        """if return_last_state is True, returns (out, last_state)
        last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
        not considered in the backward pass.
        """
        return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
    
    
    def mamba_inner_fn(
        xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
        out_proj_weight, out_proj_bias,
        A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
        C_proj_bias=None, delta_softplus=True
    ):
        return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                                  out_proj_weight, out_proj_bias,
                                  A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
    

    Change them to:

    def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                         return_last_state=False):
        """if return_last_state is True, returns (out, last_state)
        last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
        not considered in the backward pass.
        """
        return selective_scan_ref(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
    
    def mamba_inner_fn(
        xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
        out_proj_weight, out_proj_bias,
        A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
        C_proj_bias=None, delta_softplus=True
    ):
        return mamba_inner_ref(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                               out_proj_weight, out_proj_bias,
                               A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
    
    

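    I also provide Windows wheels that I compiled with the source modified to bypass selective_scan_cuda: mamba-ssm-1.1.3 and mamba_ssm-1.2.0.post1. You can download and install them directly, or contact me on WeChat to get them.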
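Why does the patched build "run but very slowly"? selective_scan_ref, which the modified selective_scan_fn now calls, evaluates the state-space recurrence one time step at a time in PyTorch instead of inside a fused CUDA kernel. The recurrence for a single (batch, channel) slice can be sketched in plain Python as below, assuming real-valued A and input-dependent B and C of shape (N, L). selective_scan_1ch is a hypothetical name; this is an illustration of the math, not a drop-in replacement.

```python
import math
from typing import List, Optional, Sequence, Tuple


def softplus(v: float) -> float:
    # Numerically stable log(1 + e^v).
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))


def silu(v: float) -> float:
    # v * sigmoid(v), written to avoid overflow for large |v|.
    if v >= 0:
        return v / (1.0 + math.exp(-v))
    e = math.exp(v)
    return v * e / (1.0 + e)


def selective_scan_1ch(u: Sequence[float], delta: Sequence[float], A: Sequence[float],
                       B: Sequence[Sequence[float]], C: Sequence[Sequence[float]],
                       D: Optional[float] = None, z: Optional[Sequence[float]] = None,
                       delta_bias: Optional[float] = None,
                       delta_softplus: bool = False) -> Tuple[List[float], List[float]]:
    """Sequential scan for one (batch, channel) slice, real A, input-dependent B and C.

    u, delta, z: length L; A: length N; B, C: N rows of length L.
    Returns (out, last_state), last_state being the hidden state after the final step.
    """
    N = len(A)
    h = [0.0] * N
    out: List[float] = []
    for t in range(len(u)):
        dt = delta[t] + (delta_bias if delta_bias is not None else 0.0)
        if delta_softplus:
            dt = softplus(dt)
        for n in range(N):
            # h_t = exp(dt * A) * h_{t-1} + dt * B_t * u_t  (discretized state update)
            h[n] = math.exp(dt * A[n]) * h[n] + dt * B[n][t] * u[t]
        y = sum(h[n] * C[n][t] for n in range(N))  # y_t = C_t . h_t
        if D is not None:
            y += D * u[t]  # skip connection
        if z is not None:
            y *= silu(z[t])  # output gate
        out.append(y)
    return out, list(h)
```

The Python-level loop over the sequence length (one step per token) is the reason the reference path is far slower than the CUDA kernel it replaces.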

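Update 20240329

On Windows, besides Docker and building from modified source, some people have successfully run the latest Mamba models through WSL; see:
Running the latest Mamba model on native Windows via WSL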

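Update 20240418

1. About the CUDA version

Many readers still get a wrong-CUDA-version error when installing causal-conv1d even after installing cuda-nvcc. This happens because the environment may still contain a CUDA_HOME (Linux) or CUDA_PATH (Windows) variable that points to the wrong location. In that case, check: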

nvcc -V
python -c "import torch.utils.cpp_extension; print(torch.utils.cpp_extension.CUDA_HOME)"

Make sure both outputs show the correct version and location. In particular, the location printed by the second command must be correct.

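On Linux, if the second command prints a location in the base environment, run which nvcc inside the virtual environment to get the correct path, then in .bashrc set export CUDA_HOME='....' to the directory that contains bin (the which nvcc output without the trailing /bin/nvcc), run source ~/.bashrc to apply the configuration, and continue the installation.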

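On Windows, use where nvcc to get the correct path inside the virtual environment, set the CUDA_PATH system environment variable to the directory that contains bin (the where nvcc output without \bin\nvcc.exe, since PyTorch looks for nvcc under CUDA_PATH\bin), and then continue the installation.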
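PyTorch's extension builder reads CUDA_HOME (falling back to CUDA_PATH) and runs nvcc from its bin subdirectory, so the variable has to name the directory above bin. The sketch below derives that directory from the nvcc found on PATH and compares it with what torch.utils.cpp_extension detected; the helper names are hypothetical.

```python
import os
import shutil
from typing import Optional


def cuda_home_from_nvcc(nvcc_path: str) -> str:
    """<root>/bin/nvcc(.exe) -> <root>, the value CUDA_HOME / CUDA_PATH should hold."""
    return os.path.dirname(os.path.dirname(os.path.abspath(nvcc_path)))


def _same_path(a: str, b: str) -> bool:
    # normcase makes the comparison case-insensitive on Windows; a no-op elsewhere.
    return os.path.normcase(os.path.abspath(a)) == os.path.normcase(os.path.abspath(b))


def check_cuda_home(torch_cuda_home: Optional[str], nvcc_path: Optional[str]) -> str:
    if nvcc_path is None:
        return "nvcc not found on PATH"
    expected = cuda_home_from_nvcc(nvcc_path)
    if torch_cuda_home is None:
        return f"PyTorch found no CUDA_HOME; set it to {expected}"
    if not _same_path(torch_cuda_home, expected):
        return f"MISMATCH: PyTorch uses {torch_cuda_home}, but nvcc on PATH lives under {expected}"
    return f"OK: {expected}"


if __name__ == "__main__":
    try:
        from torch.utils.cpp_extension import CUDA_HOME
    except ImportError:  # torch not installed in this interpreter
        CUDA_HOME = None
    print(check_cuda_home(CUDA_HOME, shutil.which("nvcc")))
```

A MISMATCH line pointing at /usr/local/cuda while nvcc lives in the conda environment is exactly the situation described above.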

For the order in which PyTorch picks a CUDA installation, see the blog post: "The order in which PyTorch selects CUDA (on cudatoolkit and /usr/local/cuda)"

2. setup hangs without progress

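On Linux, setup hangs because it is downloading the matching *.whl file from GitHub, which may require a proxy if GitHub is slow or blocked in your network. You can wait for the download to fail, at which point it prints the correct URL; then download that whl manually and pip install the file directly.
With my configuration:
causal_conv1d download link: https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.1.1/causal_conv1d-1.1.1+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
mamba_ssm download link: https://github.com/state-spaces/mamba/releases/download/v1.1.3.post1/mamba_ssm-1.1.3.post1+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl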
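Both links follow one naming pattern: package-version+cu{CUDA}torch{major.minor}cxx11abi{TRUE|FALSE}-cp{py}-cp{py}-platform.whl. The sketch below rebuilds such a URL for another version combination. The pattern is inferred only from these two links; other releases may name files differently, and a given combination may simply not have been published. wheel_url is a hypothetical helper.

```python
def wheel_url(repo: str, package: str, version: str, cuda: str, torch_version: str,
              cxx11_abi: bool, python: str, platform: str = "linux_x86_64") -> str:
    """Rebuild a release-wheel URL following the naming pattern of the two links above.

    cuda: e.g. "11.8" -> "cu118"; torch_version: major.minor such as "2.1";
    python: e.g. "3.10" -> "cp310".
    """
    local_tag = f"cu{cuda.replace('.', '')}torch{torch_version}cxx11abi{str(cxx11_abi).upper()}"
    py_tag = f"cp{python.replace('.', '')}"
    filename = f"{package}-{version}+{local_tag}-{py_tag}-{py_tag}-{platform}.whl"
    return f"https://github.com/{repo}/releases/download/v{version}/{filename}"


print(wheel_url("state-spaces/mamba", "mamba_ssm", "1.1.3.post1", "11.8", "2.1", False, "3.10"))
```

Inside the target environment, the inputs can be read from torch.version.cuda, the major.minor part of torch.__version__, torch._C._GLIBCXX_USE_CXX11_ABI, and sys.version_info.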

Update 20240424

1. causal_conv1d_cuda or selective_scan_cuda still fails after a successful installation

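Many readers who successfully install causal-conv1d still find that import causal_conv1d_cuda fails with "no module named causal_conv1d_cuda", or, after installing mamba-ssm, that import selective_scan_cuda fails with "no module named selective_scan_cuda". This is again caused by an incompatible CUDA environment. These two modules are compiled extension libraries (.so files on Linux, .pyd files on Windows). They are not inside the installed source package but directly under xxxx/envs/xxxx/lib/python3.xx/site-packages/, as causal_conv1d_cuda.cpython-310-x86_64-linux-gnu.so and selective_scan_cuda.cpython-310-x86_64-linux-gnu.so respectively (using my environment as an example).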
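To see which file Python would actually load for these modules, and whether it is found at all, importlib.util.find_spec locates a top-level module on sys.path without executing it, so it also works when the import itself fails with an undefined-symbol error. A minimal sketch; the helper names are hypothetical.

```python
import importlib.util
from typing import Dict, Iterable, Optional


def extension_location(module_name: str) -> Optional[str]:
    """Path of the file a top-level module would be loaded from, or None if not on sys.path.

    find_spec only searches sys.path and does not execute the module, so this also
    works when actually importing the .so/.pyd fails with an 'undefined symbol' error.
    """
    spec = importlib.util.find_spec(module_name)
    return spec.origin if spec is not None else None


def report(names: Iterable[str] = ("causal_conv1d_cuda", "selective_scan_cuda")) -> Dict[str, Optional[str]]:
    return {name: extension_location(name) for name in names}


if __name__ == "__main__":
    for name, path in report().items():
        print(f"{name}: {path or 'NOT FOUND on sys.path'}")
```

If a path is printed but the import still fails, the file exists and the problem is a build or CUDA mismatch; if nothing is found, the extension was never built or installed into this environment.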

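In this case, I recommend forcing a local build from source (CAUSAL_CONV1D_FORCE_BUILD=TRUE pip install . or MAMBA_FORCE_BUILD=TRUE pip install .). Some readers succeed this way; others still get errors, but now the messages are specific, for example ImportError xxxx selective_scan_cuda.cpython-xxx-linux-gnu.so undefined symbol (you can replace the file directly with a prebuilt one, selective-scan-cuda-linux-gnu.so) or ImportError xxxx causal_conv1d_cuda.cpython-xxx-linux-gnu.so undefined symbol (you can replace the file directly with a prebuilt one, causal-conv1d-cuda.cpython-310-x86-64-linux-gnu.so). Since everyone's environment differs, address each error according to its own message. [All files are available by contacting me on WeChat.]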
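The VAR=value prefix in these commands is shell syntax that does not work in Windows cmd. One portable alternative, assuming you launch the build from Python, is to pass the flag through the env argument of subprocess.run, which sets it only for that child process. The helper names are hypothetical; the variable names are the ones used in this post, and read_flag mirrors the unmodified setup.py check shown in Update 20240313 (in its original == "TRUE" form).

```python
import os
import sys
from typing import Dict, List, Mapping, Optional


def build_env(flags: Mapping[str, str], base: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Copy of the current (or given) environment with the build flags added."""
    env = dict(os.environ if base is None else base)
    env.update(flags)
    return env


def pip_install_cmd(path: str = ".") -> List[str]:
    """Run pip via the current interpreter so the active conda env is the install target."""
    return [sys.executable, "-m", "pip", "install", path]


def read_flag(env: Mapping[str, str], name: str) -> bool:
    """Same test as the unmodified setup.py: os.getenv(name, "FALSE") == "TRUE"."""
    return env.get(name, "FALSE") == "TRUE"


# Usage, from inside the cloned causal-conv1d (or mamba) directory:
#   import subprocess
#   subprocess.run(pip_install_cmd("."),
#                  env=build_env({"CAUSAL_CONV1D_FORCE_BUILD": "TRUE"}), check=True)
```

Unlike set in cmd, this does not leave the flag set for later commands in the same terminal.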

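A .so undefined symbol error is usually caused by a CUDA version mismatch; see "1. About the CUDA version" in Update 20240418 of this post. For example, inside the virtual environment which nvcc points to the environment's own CUDA, but python -c "import torch.utils.cpp_extension; print(torch.utils.cpp_extension.CUDA_HOME)" prints the system-wide /usr/local/cuda instead.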

Alternatively, as on Windows, you can modify the source files to bypass the calls to causal_conv1d_cuda and selective_scan_cuda.
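Instead of commenting the import out permanently, another option (a design alternative, not what the referenced Windows Support issue does) is to attempt the import and fall back to the reference implementation only when it fails, so the same source runs fast where the extension loads and slowly elsewhere. A minimal sketch; try_import and with_fallback are hypothetical helpers.

```python
import importlib
import warnings
from types import ModuleType
from typing import Callable, Optional


def try_import(name: str) -> Optional[ModuleType]:
    """Import an optional compiled extension; return None (with a warning) if it cannot load."""
    try:
        return importlib.import_module(name)
    except ImportError as exc:  # "No module named ..." and "... undefined symbol" are both ImportError
        warnings.warn(f"{name} unavailable ({exc}); using the slow reference path")
        return None


def with_fallback(fast: Callable, ref: Callable, ext: Optional[ModuleType]) -> Callable:
    """Return the CUDA-backed callable if the extension loaded, otherwise the reference one."""
    return fast if ext is not None else ref


# How this could look in selective_scan_interface.py (sketch):
#   selective_scan_cuda = try_import("selective_scan_cuda")
# and inside selective_scan_fn:
#   impl = with_fallback(SelectiveScanFn.apply, selective_scan_ref, selective_scan_cuda)
#   return impl(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
```

This works because SelectiveScanFn.apply and selective_scan_ref take the same positional arguments in the code shown in Update 20240313.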

2. Bypassing causal_conv1d_cuda and selective_scan_cuda on Windows

  • causal_conv1d_cuda: in causal_conv1d_interface.py, comment out import causal_conv1d_cuda, and change

    def causal_conv1d_fn(
        x,
        weight,
        bias=None,
        seq_idx=None,
        initial_states=None,
        return_final_states=False,
        final_states_out=None,
        activation=None,
    ):
        """
        x: (batch, dim, seqlen)
        weight: (dim, width)
        bias: (dim,)
        seq_idx: (batch, seqlen)
        initial_states: (batch, dim, width - 1)
        final_states_out: (batch, dim, width - 1), to be written to
        activation: either None or "silu" or "swish"
    
        out: (batch, dim, seqlen)
        """
        return CausalConv1dFn.apply(
            x,
            weight,
            bias,
            seq_idx,
            initial_states,
            return_final_states,
            final_states_out,
            activation,
        )
    

    to:

    def causal_conv1d_fn(
        x,
        weight,
        bias=None,
        seq_idx=None,
        initial_states=None,
        return_final_states=False,
        final_states_out=None,
        activation=None,
    ):
        """
        x: (batch, dim, seqlen)
        weight: (dim, width)
        bias: (dim,)
        seq_idx: (batch, seqlen)
        initial_states: (batch, dim, width - 1)
        final_states_out: (batch, dim, width - 1), to be written to
        activation: either None or "silu" or "swish"
    
        out: (batch, dim, seqlen)
        """
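        # causal_conv1d_ref has no seq_idx parameter (seq_idx is ignored on this path),
        # so pass the remaining arguments by keyword to avoid shifting them by one.
        return causal_conv1d_ref(
            x,
            weight,
            bias,
            initial_states=initial_states,
            return_final_states=return_final_states,
            final_states_out=final_states_out,
            activation=activation,
        )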
    

    The function may differ slightly between versions, but in every version this is the function to change.

  • causal_conv1d_cuda for Vim (Vision Mamba: Efficient Visual Representation Learning with Bidirectional State Space Model): change

    def causal_conv1d_fn(x, weight, bias=None, activation=None):
        """
        x: (batch, dim, seqlen)
        weight: (dim, width)
        bias: (dim,)
        activation: either None or "silu" or "swish"
    
        out: (batch, dim, seqlen)
        """
        return CausalConv1dFn.apply(x, weight, bias, activation)
    

    to:

    def causal_conv1d_fn(x, weight, bias=None, activation=None):
        """
        x: (batch, dim, seqlen)
        weight: (dim, width)
        bias: (dim,)
        activation: either None or "silu" or "swish"
    
        out: (batch, dim, seqlen)
        """
        return causal_conv1d_ref(x, weight, bias, activation=activation)
    
  • selective_scan_cuda: see Update 20240313

  • selective_scan_cuda for Vim (Vision Mamba: Efficient Visual Representation Learning with Bidirectional State Space Model): Vim modifies the mamba source, so running Vision Mamba on Windows requires corresponding changes to avoid this module, as shown below.

    # Copyright (c) 2023, Tri Dao, Albert Gu.
    
    import torch
    import torch.nn.functional as F
    from torch.cuda.amp import custom_bwd, custom_fwd
    
    from einops import rearrange, repeat
    
    try:
        from causal_conv1d import causal_conv1d_fn
        import causal_conv1d_cuda
    except ImportError:
        causal_conv1d_fn = None
        causal_conv1d_cuda = None
    
    
    # import selective_scan_cuda
    
    
    class SelectiveScanFn(torch.autograd.Function):
    
        @staticmethod
        def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                    return_last_state=False):
            if u.stride(-1) != 1:
                u = u.contiguous()
            if delta.stride(-1) != 1:
                delta = delta.contiguous()
            if D is not None:
                D = D.contiguous()
            if B.stride(-1) != 1:
                B = B.contiguous()
            if C.stride(-1) != 1:
                C = C.contiguous()
            if z is not None and z.stride(-1) != 1:
                z = z.contiguous()
            if B.dim() == 3:
                B = rearrange(B, "b dstate l -> b 1 dstate l")
                ctx.squeeze_B = True
            if C.dim() == 3:
                C = rearrange(C, "b dstate l -> b 1 dstate l")
                ctx.squeeze_C = True
            out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
            ctx.delta_softplus = delta_softplus
            ctx.has_z = z is not None
            last_state = x[:, :, -1, 1::2]  # (batch, dim, dstate)
            if not ctx.has_z:
                ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
                return out if not return_last_state else (out, last_state)
            else:
                ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
                out_z = rest[0]
                return out_z if not return_last_state else (out_z, last_state)
    
        @staticmethod
        def backward(ctx, dout, *args):
            if not ctx.has_z:
                u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
                z = None
                out = None
            else:
                u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
            if dout.stride(-1) != 1:
                dout = dout.contiguous()
            # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
            # backward of selective_scan_cuda with the backward of chunk).
            # Here we just pass in None and dz will be allocated in the C++ code.
            du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
                u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus,
                False  # option to recompute out_z, not used here
            )
            dz = rest[0] if ctx.has_z else None
            dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
            dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
            return (du, ddelta, dA, dB, dC,
                    dD if D is not None else None,
                    dz,
                    ddelta_bias if delta_bias is not None else None,
                    None,
                    None)
    
    
    def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                          return_last_state=False):
        """if return_last_state is True, returns (out, last_state)
        last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
        not considered in the backward pass.
        """
        return selective_scan_ref(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
    
    
    def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                           return_last_state=False):
        """
        u: r(B D L)
        delta: r(B D L)
        A: c(D N) or r(D N)
        B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
        C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
        D: r(D)
        z: r(B D L)
        delta_bias: r(D), fp32
    
        out: r(B D L)
        last_state (optional): r(B D dstate) or c(B D dstate)
        """
        dtype_in = u.dtype
        u = u.float()
        delta = delta.float()
        if delta_bias is not None:
            delta = delta + delta_bias[..., None].float()
        if delta_softplus:
            delta = F.softplus(delta)
        batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
        is_variable_B = B.dim() >= 3
        is_variable_C = C.dim() >= 3
        if A.is_complex():
            if is_variable_B:
                B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
            if is_variable_C:
                C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
        else:
            B = B.float()
            C = C.float()
        x = A.new_zeros((batch, dim, dstate))
        ys = []
        deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
        if not is_variable_B:
            deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
        else:
            if B.dim() == 3:
                deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
            else:
                B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
                deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
        if is_variable_C and C.dim() == 4:
            C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
        last_state = None
        for i in range(u.shape[2]):
            x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
            if not is_variable_C:
                y = torch.einsum('bdn,dn->bd', x, C)
            else:
                if C.dim() == 3:
                    y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
                else:
                    y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
            if i == u.shape[2] - 1:
                last_state = x
            if y.is_complex():
                y = y.real * 2
            ys.append(y)
        y = torch.stack(ys, dim=2)  # (batch dim L)
        out = y if D is None else y + u * rearrange(D, "d -> d 1")
        if z is not None:
            out = out * F.silu(z)
        out = out.to(dtype=dtype_in)
        return out if not return_last_state else (out, last_state)
    
    
    class MambaInnerFnNoOutProj(torch.autograd.Function):
    
        @staticmethod
        @custom_fwd
        def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
                    C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
            """
                 xz: (batch, dim, seqlen)
            """
            assert checkpoint_lvl in [0, 1]
            L = xz.shape[-1]
            delta_rank = delta_proj_weight.shape[1]
            d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
            if torch.is_autocast_enabled():
                x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
                delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            if xz.stride(-1) != 1:
                xz = xz.contiguous()
            conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
            x, z = xz.chunk(2, dim=1)
            conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
            # We're being very careful here about the layout, to avoid extra transposes.
            # We want delta to have d as the slowest moving dimension
            # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
            x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
            delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
            ctx.is_variable_B = B is None
            ctx.is_variable_C = C is None
            ctx.B_proj_bias_is_None = B_proj_bias is None
            ctx.C_proj_bias_is_None = C_proj_bias is None
            if B is None:  # variable B
                B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl dstate)
                if B_proj_bias is not None:
                    B = B + B_proj_bias.to(dtype=B.dtype)
                if not A.is_complex():
                    # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
                    B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
                else:
                    B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
            else:
                if B.stride(-1) != 1:
                    B = B.contiguous()
            if C is None:  # variable C
                C = x_dbl[:, -d_state:]  # (bl dstate)
                if C_proj_bias is not None:
                    C = C + C_proj_bias.to(dtype=C.dtype)
                if not A.is_complex():
                    # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
                    C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
                else:
                    C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
            else:
                if C.stride(-1) != 1:
                    C = C.contiguous()
            if D is not None:
                D = D.contiguous()
            out, scan_intermediates, out_z = selective_scan_cuda.fwd(
                conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
            )
            ctx.delta_softplus = delta_softplus
            ctx.checkpoint_lvl = checkpoint_lvl
            if checkpoint_lvl >= 1:  # Will recompute conv1d_out and delta in the backward pass
                conv1d_out, delta = None, None
            ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
                                  delta_proj_weight, conv1d_out, delta,
                                  A, B, C, D, delta_bias, scan_intermediates, out)
            # return rearrange(out_z, "b d l -> b l d")
            return out_z
    
        @staticmethod
        @custom_bwd
        def backward(ctx, dout):
            # dout: (batch, seqlen, dim)
            (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight,
             conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
            L = xz.shape[-1]
            delta_rank = delta_proj_weight.shape[1]
            d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
            x, z = xz.chunk(2, dim=1)
            if dout.stride(-1) != 1:
                dout = dout.contiguous()
            if ctx.checkpoint_lvl == 1:
                conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
                delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
                                  "d (b l) -> b d l", l=L)
            # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
            # backward of selective_scan_cuda with the backward of chunk).
            dxz = torch.empty_like(xz)  # (batch, dim, seqlen)
            dx, dz = dxz.chunk(2, dim=1)
            # dout_y = rearrange(dout, "b l d -> b d l") # because no arrange at end of forward, so dout shape is b d l
            dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
                conv1d_out, delta, A, B, C, D, z, delta_bias, dout, scan_intermediates, out, dz,
                ctx.delta_softplus,
                True  # option to recompute out_z
            )
            dD = dD if D is not None else None
            dx_dbl = torch.empty_like(x_dbl)
            dB_proj_bias = None
            if ctx.is_variable_B:
                if not A.is_complex():
                    dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
                else:
                    dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
                dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
                dx_dbl[:, delta_rank:delta_rank + d_state] = dB  # (bl d)
                dB = None
            dC_proj_bias = None
            if ctx.is_variable_C:
                if not A.is_complex():
                    dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
                else:
                    dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
                dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
                dx_dbl[:, -d_state:] = dC  # (bl d)
                dC = None
            ddelta = rearrange(ddelta, "b d l -> d (b l)")
            ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
            dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
            dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
            dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
            dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
            dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
            # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
            # backward of conv1d with the backward of chunk).
            dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd(
                x, conv1d_weight, conv1d_bias, dconv1d_out, None, dx, True
            )
            dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
            dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
            return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
                    dA, dB, dC, dD,
                    ddelta_bias if delta_bias is not None else None,
                    dB_proj_bias, dC_proj_bias, None)
    
    
    class MambaInnerFn(torch.autograd.Function):
    
        @staticmethod
        @custom_fwd
        def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                    out_proj_weight, out_proj_bias,
                    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
                    C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
            """
                 xz: (batch, dim, seqlen)
            """
            assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
            assert checkpoint_lvl in [0, 1]
            L = xz.shape[-1]
            delta_rank = delta_proj_weight.shape[1]
            d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
            if torch.is_autocast_enabled():
                x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
                delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
                out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
                out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
                                 if out_proj_bias is not None else None)
            if xz.stride(-1) != 1:
                xz = xz.contiguous()
            conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
            x, z = xz.chunk(2, dim=1)
            conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
                x, conv1d_weight, conv1d_bias, None, None, None, True
            )
            # We're being very careful here about the layout, to avoid extra transposes.
            # We want delta to have d as the slowest moving dimension
            # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
            x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
            delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
            ctx.is_variable_B = B is None
            ctx.is_variable_C = C is None
            ctx.B_proj_bias_is_None = B_proj_bias is None
            ctx.C_proj_bias_is_None = C_proj_bias is None
            if B is None:  # variable B
                B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl dstate)
                if B_proj_bias is not None:
                    B = B + B_proj_bias.to(dtype=B.dtype)
                if not A.is_complex():
                    # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
                    B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
                else:
                    B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
            else:
                if B.stride(-1) != 1:
                    B = B.contiguous()
            if C is None:  # variable C
                C = x_dbl[:, -d_state:]  # (bl dstate)
                if C_proj_bias is not None:
                    C = C + C_proj_bias.to(dtype=C.dtype)
                if not A.is_complex():
                    # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
                    C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
                else:
                    C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
            else:
                if C.stride(-1) != 1:
                    C = C.contiguous()
            if D is not None:
                D = D.contiguous()
            out, scan_intermediates, out_z = selective_scan_cuda.fwd(
                conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
            )
            ctx.delta_softplus = delta_softplus
            ctx.out_proj_bias_is_None = out_proj_bias is None
            ctx.checkpoint_lvl = checkpoint_lvl
            if checkpoint_lvl >= 1:  # Will recompute conv1d_out and delta in the backward pass
                conv1d_out, delta = None, None
            ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
                                  delta_proj_weight, out_proj_weight, conv1d_out, delta,
                                  A, B, C, D, delta_bias, scan_intermediates, out)
            return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
    
        @staticmethod
        @custom_bwd
        def backward(ctx, dout):
            # dout: (batch, seqlen, dim)
            assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
            (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
             conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
            L = xz.shape[-1]
            delta_rank = delta_proj_weight.shape[1]
            d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
            x, z = xz.chunk(2, dim=1)
            if dout.stride(-1) != 1:
                dout = dout.contiguous()
            if ctx.checkpoint_lvl == 1:
                conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
                    x, conv1d_weight, conv1d_bias, None, None, None, True
                )
                delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
                                  "d (b l) -> b d l", l=L)
            # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
            # backward of selective_scan_cuda with the backward of chunk).
            dxz = torch.empty_like(xz)  # (batch, dim, seqlen)
            dx, dz = dxz.chunk(2, dim=1)
            dout = rearrange(dout, "b l e -> e (b l)")
            dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
            dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
                conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz,
                ctx.delta_softplus,
                True  # option to recompute out_z
            )
            dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
            dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
            dD = dD if D is not None else None
            dx_dbl = torch.empty_like(x_dbl)
            dB_proj_bias = None
            if ctx.is_variable_B:
                if not A.is_complex():
                    dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
                else:
                    dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
                dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
                dx_dbl[:, delta_rank:delta_rank + d_state] = dB  # (bl d)
                dB = None
            dC_proj_bias = None
            if ctx.is_variable_C:
                if not A.is_complex():
                    dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
                else:
                    dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
                dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
                dx_dbl[:, -d_state:] = dC  # (bl d)
                dC = None
            ddelta = rearrange(ddelta, "b d l -> d (b l)")
            ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
            dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
            dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
            dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
            dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
            dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
            # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
            # backward of conv1d with the backward of chunk).
            dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
                x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
            )
            dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
            dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
            return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
                    dout_proj_weight, dout_proj_bias,
                    dA, dB, dC, dD,
                    ddelta_bias if delta_bias is not None else None,
                    dB_proj_bias, dC_proj_bias, None)
    
    
    class BiMambaInnerFn(torch.autograd.Function):
    
        @staticmethod
        @custom_fwd
        def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                    out_proj_weight, out_proj_bias,
                    A, A_b, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
                    C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
            """
                 xz: (batch, dim, seqlen)
            """
            assert checkpoint_lvl in [0, 1]
            L = xz.shape[-1]
            delta_rank = delta_proj_weight.shape[1]
            d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
            if torch.is_autocast_enabled():
                x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
                delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
                out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
                out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
                                 if out_proj_bias is not None else None)
            if xz.stride(-1) != 1:
                xz = xz.contiguous()
            conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
            x, z = xz.chunk(2, dim=1)
            conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
            # We're being very careful here about the layout, to avoid extra transposes.
            # We want delta to have d as the slowest moving dimension
            # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
            x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
            delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
            ctx.is_variable_B = B is None
            ctx.is_variable_C = C is None
            ctx.B_proj_bias_is_None = B_proj_bias is None
            ctx.C_proj_bias_is_None = C_proj_bias is None
            if B is None:  # variable B
                B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl dstate)
                if B_proj_bias is not None:
                    B = B + B_proj_bias.to(dtype=B.dtype)
                if not A.is_complex():
                    # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
                    B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
                else:
                    B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
            else:
                if B.stride(-1) != 1:
                    B = B.contiguous()
            if C is None:  # variable C
                C = x_dbl[:, -d_state:]  # (bl dstate)
                if C_proj_bias is not None:
                    C = C + C_proj_bias.to(dtype=C.dtype)
                if not A.is_complex():
                    # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
                    C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
                else:
                    C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
            else:
                if C.stride(-1) != 1:
                    C = C.contiguous()
            if D is not None:
                D = D.contiguous()
            out_f, scan_intermediates_f, out_z_f = selective_scan_cuda.fwd(
                conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
            )
            assert not A_b.is_complex(), "A should not be complex!!"
            out_b, scan_intermediates_b, out_z_b = selective_scan_cuda.fwd(
                conv1d_out.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D, z.flip([-1]), delta_bias,
                delta_softplus,
            )
    
            out_z = out_z_f + out_z_b.flip([-1])
    
            ctx.delta_softplus = delta_softplus
            ctx.out_proj_bias_is_None = out_proj_bias is None
            ctx.checkpoint_lvl = checkpoint_lvl
            if checkpoint_lvl >= 1:  # Will recompute conv1d_out and delta in the backward pass
                conv1d_out, delta = None, None
            ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
                                  delta_proj_weight, out_proj_weight, conv1d_out, delta,
                                  A, A_b, B, C, D, delta_bias, scan_intermediates_f, scan_intermediates_b, out_f, out_b)
            return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
    
        @staticmethod
        @custom_bwd
        def backward(ctx, dout):
            # dout: (batch, seqlen, dim)
            (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
             conv1d_out, delta, A, A_b, B, C, D, delta_bias, scan_intermediates_f, scan_intermediates_b, out_f,
             out_b) = ctx.saved_tensors
            L = xz.shape[-1]
            delta_rank = delta_proj_weight.shape[1]
            d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
            x, z = xz.chunk(2, dim=1)
            if dout.stride(-1) != 1:
                dout = dout.contiguous()
            if ctx.checkpoint_lvl == 1:
                conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
                delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
                                  "d (b l) -> b d l", l=L)
            # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
            # backward of selective_scan_cuda with the backward of chunk).
            dxz = torch.empty_like(xz)  # (batch, dim, seqlen)
            dx, dz = dxz.chunk(2, dim=1)
            dout = rearrange(dout, "b l e -> e (b l)")
            dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
            dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z_f = selective_scan_cuda.bwd(
                conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates_f, out_f, dz,
                ctx.delta_softplus,
                True  # option to recompute out_z
            )
            # flip one
            dz_b = torch.empty_like(dz)
            dconv1d_out_f_b, ddelta_f_b, dA_b, dB_f_b, dC_f_b, dD_b, ddelta_bias_b, dz_b, out_z_b = selective_scan_cuda.bwd(
                conv1d_out.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D, z.flip([-1]), delta_bias,
                dout_y.flip([-1]), scan_intermediates_b, out_b, dz_b,
                ctx.delta_softplus,
                True  # option to recompute out_z
            )
    
            dconv1d_out = dconv1d_out + dconv1d_out_f_b.flip([-1])
            ddelta = ddelta + ddelta_f_b.flip([-1])
            dB = dB + dB_f_b.flip([-1])
            dC = dC + dC_f_b.flip([-1])
            dD = dD + dD_b
            ddelta_bias = ddelta_bias + ddelta_bias_b
            dz = dz + dz_b.flip([-1])
            out_z = out_z_f + out_z_b.flip([-1])
    
            dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
            dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
            dD = dD if D is not None else None
            dx_dbl = torch.empty_like(x_dbl)
            dB_proj_bias = None
            if ctx.is_variable_B:
                if not A.is_complex():
                    dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
                else:
                    dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
                dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
                dx_dbl[:, delta_rank:delta_rank + d_state] = dB  # (bl d)
                dB = None
            dC_proj_bias = None
            if ctx.is_variable_C:
                if not A.is_complex():
                    dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
                else:
                    dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
                dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
                dx_dbl[:, -d_state:] = dC  # (bl d)
                dC = None
            ddelta = rearrange(ddelta, "b d l -> d (b l)")
            ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
            dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
            dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
            dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
            dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
            dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
            # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
            # backward of conv1d with the backward of chunk).
            dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd(
                x, conv1d_weight, conv1d_bias, dconv1d_out, None, dx, True
            )
            dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
            dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
            return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
                    dout_proj_weight, dout_proj_bias,
                    dA, dA_b, dB, dC, dD,
                    ddelta_bias if delta_bias is not None else None,
                    dB_proj_bias, dC_proj_bias, None)
    
    
    def mamba_inner_fn(
            xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
            out_proj_weight, out_proj_bias,
            A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
            C_proj_bias=None, delta_softplus=True
    ):
        return mamba_inner_ref(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                               out_proj_weight, out_proj_bias,
                               A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
    
    
    def bimamba_inner_fn(
            xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
            out_proj_weight, out_proj_bias,
            A, A_b, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
            C_proj_bias=None, delta_softplus=True
    ):
        return bimamba_inner_ref(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                                 out_proj_weight, out_proj_bias,
                                 A, A_b, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
    
    
    def mamba_inner_fn_no_out_proj(
            xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
            A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
            C_proj_bias=None, delta_softplus=True
    ):
        return mamba_inner_ref_fn_no_out_proj(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                                              A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
    
    
    def mamba_inner_ref(
            xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
            out_proj_weight, out_proj_bias,
            A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
            C_proj_bias=None, delta_softplus=True
    ):
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        x, z = xz.chunk(2, dim=1)
        x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu")
        # We're being very careful here about the layout, to avoid extra transposes.
        # We want delta to have d as the slowest moving dimension
        # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
        x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
        delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
        delta = rearrange(delta, "d (b l) -> b d l", l=L)
        if B is None:  # variable B
            B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl d)
            if B_proj_bias is not None:
                B = B + B_proj_bias.to(dtype=B.dtype)
            if not A.is_complex():
                B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
            else:
                B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
        if C is None:  # variable C
            C = x_dbl[:, -d_state:]  # (bl d)
            if C_proj_bias is not None:
                C = C + C_proj_bias.to(dtype=C.dtype)
            if not A.is_complex():
                C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
            else:
                C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
        y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True)
        return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
    
    
    def mamba_inner_ref_fn_no_out_proj(
            xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
            A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
            C_proj_bias=None, delta_softplus=True
    ):
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        x, z = xz.chunk(2, dim=1)
        x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu")
        # We're being very careful here about the layout, to avoid extra transposes.
        # We want delta to have d as the slowest moving dimension
        # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
        x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
        delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
        delta = rearrange(delta, "d (b l) -> b d l", l=L)
        if B is None:  # variable B
            B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl d)
            if B_proj_bias is not None:
                B = B + B_proj_bias.to(dtype=B.dtype)
            if not A.is_complex():
                B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
            else:
                B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
        if C is None:  # variable C
            C = x_dbl[:, -d_state:]  # (bl d)
            if C_proj_bias is not None:
                C = C + C_proj_bias.to(dtype=C.dtype)
            if not A.is_complex():
                C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
            else:
                C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
        y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True)
        # return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
        return y
    
    
    def bimamba_inner_ref(
            xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
            out_proj_weight, out_proj_bias,
            A, A_b, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
            C_proj_bias=None, delta_softplus=True
    ):
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        x, z = xz.chunk(2, dim=1)
        x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu")
        # We're being very careful here about the layout, to avoid extra transposes.
        # We want delta to have d as the slowest moving dimension
        # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
        x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
        delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
        delta = rearrange(delta, "d (b l) -> b d l", l=L)
        if B is None:  # variable B
            B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl d)
            if B_proj_bias is not None:
                B = B + B_proj_bias.to(dtype=B.dtype)
            if not A.is_complex():
                B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
            else:
                B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
        if C is None:  # variable C
            C = x_dbl[:, -d_state:]  # (bl d)
            if C_proj_bias is not None:
                C = C + C_proj_bias.to(dtype=C.dtype)
            if not A.is_complex():
                C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
            else:
                C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
        y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True)
        y_b = selective_scan_fn(x.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D, z.flip([-1]),
                                delta_bias, delta_softplus=True)
        y = y + y_b.flip([-1])
        return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
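bimamba_inner_ref above runs selective_scan_fn once over the sequence and once over the time-flipped sequence (with its own A_b), flips the second result back, and sums the two. The toy sketch below shows why this flip-scan-flip pattern lets every position see both earlier and later tokens. It is a hypothetical illustration only: a scalar linear recurrence h_t = a * h_{t-1} + x_t stands in for the real input-dependent, multi-channel selective scan.

```python
from typing import List


def causal_scan(xs: List[float], a: float) -> List[float]:
    """Toy stand-in for a causal SSM scan: h_t = a * h_{t-1} + x_t."""
    h = 0.0
    out = []
    for x in xs:
        h = a * h + x
        out.append(h)
    return out


def bidirectional_scan(xs: List[float], a_f: float, a_b: float) -> List[float]:
    """Mirrors y = y_f + flip(scan(flip(x))) from bimamba_inner_ref."""
    y_f = causal_scan(xs, a_f)
    # Scan the reversed sequence, then flip the result back to the original order.
    y_b = causal_scan(xs[::-1], a_b)[::-1]
    return [f + b for f, b in zip(y_f, y_b)]


if __name__ == "__main__":
    impulse_last = [0.0, 0.0, 1.0]
    print(causal_scan(impulse_last, 0.5))              # [0.0, 0.0, 1.0]
    print(bidirectional_scan(impulse_last, 0.5, 0.5))  # [0.25, 0.5, 2.0]
```

With only the causal scan, an impulse at the last position cannot affect earlier outputs; after adding the flipped backward scan it reaches positions 0 and 1 as well.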
    

20240506 Update

After installing mamba-ssm with pip install mamba-ssm, code that had been running fine started failing with the following error:

File "/home/xxx/.conda/envs/mamba/lib/python3.10/site-packages/mamba_ssm/ops/selective_scan_interface.py", line 187, in forward
    conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
TypeError: causal_conv1d_fwd(): incompatible function arguments. The following argument types are supported:
    1. (arg0: torch.Tensor, arg1: torch.Tensor, arg2: Optional[torch.Tensor], arg3: Optional[torch.Tensor], arg4: bool) -> torch.Tensor

Invoked with: tensor(
        [-4.9056e-40, -4.9057e-40, -4.9074e-40, -4.9078e-40]], device='cuda:0',
       requires_grad=True), Parameter containing:
tensor([ 0.0322, -0.1139,  0.0770,  ..., -0.0320, -0.1266, -0.1096],
       device='cuda:0', requires_grad=True), None, None, None, True

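On inspection, this is a mamba-ssm version problem. The failing version is 1.2.0.post1: pip install mamba-ssm installs the latest release, which is partly incompatible with the earlier functions (here it passes more arguments to causal_conv1d_fwd than the installed causal_conv1d build accepts; the supported signature in the message has only five). The version that had been running fine is 1.1.3.post1, so pinning it with pip install mamba-ssm==1.1.3.post1 should avoid this error.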
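Before and after changing versions, it helps to confirm what is actually installed. A minimal sketch using the standard library's importlib.metadata; the helper names are hypothetical, and 1.1.3.post1 is simply the version that worked in this post, not an official compatibility pin.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional

KNOWN_GOOD_MAMBA = "1.1.3.post1"  # version that worked in this post, not an official pin


def installed_version(dist_name: str) -> Optional[str]:
    """Installed version of a pip distribution, or None if it is not installed."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


def check_mamba(known_good: str = KNOWN_GOOD_MAMBA) -> bool:
    """Print both versions; return True if mamba-ssm matches the known-good one."""
    mamba = installed_version("mamba-ssm")
    conv = installed_version("causal-conv1d")
    print(f"mamba-ssm: {mamba}, causal-conv1d: {conv}")
    if mamba != known_good:
        print(f"to pin the known-good release: pip install mamba-ssm=={known_good}")
    return mamba == known_good


if __name__ == "__main__":
    check_mamba()
```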

20240523 Update

A reader hit the following error when running Vision Mamba (Linux):

TypeError: Mamba.__init__() got an unexpected keyword argument 'bimamba_type'

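Vision Mamba modifies Mamba's source code, so the package installed through Mamba's official channels does not have this parameter. You need to uninstall the stock Mamba first and then install Mamba manually from the Mamba source shipped with the Vision Mamba code, rather than from the official channels. In practice, replacing files directly also works: overwrite selective_scan_interface.py with Vision Mamba's selective_scan_interface.py, and likewise replace causal_conv1d_interface.py and mamba_simple.py.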
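To tell which Mamba is installed without running a model, you can check whether Mamba.__init__ declares bimamba_type. A minimal sketch with inspect.signature; StockMamba and VimMamba are hypothetical stand-ins for the two class variants, and in a real environment you would pass mamba_ssm.modules.mamba_simple.Mamba instead.

```python
import inspect


def has_parameter(func, name: str) -> bool:
    """True if `func` declares a parameter called `name`."""
    return name in inspect.signature(func).parameters


class StockMamba:  # hypothetical stand-in for the official mamba_ssm Mamba
    def __init__(self, d_model, d_state=16):
        self.d_model, self.d_state = d_model, d_state


class VimMamba:  # hypothetical stand-in for Vision Mamba's modified Mamba
    def __init__(self, d_model, d_state=16, bimamba_type="none"):
        self.d_model, self.d_state, self.bimamba_type = d_model, d_state, bimamba_type


if __name__ == "__main__":
    print(has_parameter(StockMamba.__init__, "bimamba_type"))  # False
    print(has_parameter(VimMamba.__init__, "bimamba_type"))    # True
    # In a real environment (requires mamba_ssm to be importable):
    # from mamba_ssm.modules.mamba_simple import Mamba
    # print(has_parameter(Mamba.__init__, "bimamba_type"))
```

If the real check prints False, the stock package is still the one being imported.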

20240531 Update

Recently some readers hit the following error during installation: ImportError: cannot import name 'packaging' from 'pkg_resources'. The cause is a setuptools version that is too new (typically 70.0.0); downgrade it with pip install setuptools==68.2.2.
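A minimal sketch for checking this before building, assuming (as reported here) that setuptools 70.x and above is the culprit; the threshold and helper names belong to this sketch, not to setuptools.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional

SUSPECT_MAJOR = 70  # from this post: setuptools 70.0.0 triggered the pkg_resources error


def installed_setuptools() -> Optional[str]:
    """Installed setuptools version string, or None if setuptools is absent."""
    try:
        return version("setuptools")
    except PackageNotFoundError:
        return None


def needs_downgrade(ver: str, suspect_major: int = SUSPECT_MAJOR) -> bool:
    """True if the major version is at or above the one reported to break the build."""
    return int(ver.split(".")[0]) >= suspect_major


if __name__ == "__main__":
    current = installed_setuptools()
    if current is None:
        print("setuptools is not installed")
    elif needs_downgrade(current):
        print(f"setuptools {current} found; try: pip install setuptools==68.2.2")
    else:
        print(f"setuptools {current} is below {SUSPECT_MAJOR}.x")
```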

20240604 Update

1. selective_scan_cuda not found on Linux

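A reader configuring Vision Mamba hit the following error:
NameError: name 'selective_scan_cuda' is not defined. Did you mean: 'selective_scan_fn'
This happens because causal_conv1d and mamba_ssm were never installed as the original repo instructs; the source was simply copied. Following the original installation instructions will probably fail as well, so here is a simpler method: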

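  1. Install the stock causal_conv1d and mamba_ssm normally, following item 3 of Solution (Linux) in this post.
  2. Replace the causal_conv1d and mamba_ssm already installed in the environment with the causal_conv1d and mamba_ssm from the Vision Mamba project.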
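A quick way to confirm whether the compiled CUDA extensions are visible before running Vision Mamba is to ask Python where it would load them from. A minimal sketch using importlib.util.find_spec; the helper name is hypothetical, and since it only locates the modules without importing them, it will not catch CUDA or ABI mismatches that surface at import time.

```python
import importlib.util
from typing import Dict, Iterable, Optional

EXTENSIONS = ("selective_scan_cuda", "causal_conv1d_cuda")


def locate_modules(names: Iterable[str] = EXTENSIONS) -> Dict[str, Optional[str]]:
    """Map each top-level module name to the file it would load from, or None if missing."""
    found = {}
    for name in names:
        spec = importlib.util.find_spec(name)  # does not execute the module
        found[name] = spec.origin if spec is not None else None
    return found


if __name__ == "__main__":
    for name, origin in locate_modules().items():
        print(f"{name}: {origin or 'NOT FOUND (compiled extension not installed)'}")
```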

2. ModuleNotFoundError: No module named 'causal_conv1d_cuda' on Windows

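This is actually an old problem; see the first issue under the 20240424 Update in this post. It means causal_conv1d was not installed successfully. After a successful install there should be a file causal_conv1d_cuda.cp310-win_amd64.pyd (taking my environment as an example) under xxxx\envs\xxxx\Lib\site-packages\. The file can be downloaded from causal-conv1d-cuda.cp310-win-amd64.pyd. You can either build from source on Windows following the method described earlier in this post, or, once the preceding environment is set up, download and install the wheel I compiled. [All files are available from me via WeChat.]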
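The .pyd name encodes the Python version and platform (cp310, win_amd64), so a prebuilt file only loads if its suffix matches your interpreter. A minimal sketch that derives the expected file name from sysconfig's EXT_SUFFIX and checks the environment's site-packages; the helper names are hypothetical.

```python
import os
import sysconfig


def expected_extension_filename(module: str) -> str:
    """File name the running interpreter expects for a compiled extension module."""
    # e.g. '.cp310-win_amd64.pyd' on 64-bit Windows CPython 3.10,
    # '.cpython-310-x86_64-linux-gnu.so' on 64-bit Linux.
    return module + sysconfig.get_config_var("EXT_SUFFIX")


def extension_present(module: str) -> bool:
    """Check the environment's platform-specific site-packages for that file."""
    site_packages = sysconfig.get_paths()["platlib"]
    return os.path.isfile(os.path.join(site_packages, expected_extension_filename(module)))


if __name__ == "__main__":
    name = "causal_conv1d_cuda"
    print(expected_extension_filename(name), extension_present(name))
```

If the expected name and the downloaded file differ (say cp39 vs. cp310), the file was built for a different Python and will not be picked up.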

20240607 Update

1. Configuring the Vim environment on Windows

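First configure Mamba following Solution (Win) described earlier, then modify the source of the installed mamba package to match the Vision Mamba source. See "causal_conv1d_cuda for Vim" and "selective_scan_cuda for Vim" under the 20240424 Update earlier in this post.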

2. KeyError: 'HOME' when running Vim on Windows

Specifically, the following error appears:

Traceback (most recent call last):
  .....
  File "xxx\models\vimamba.py", line 115, in forward
    hidden_states, residual = fused_add_norm_fn(
  File "D:\Anaconda\envs\xxx\lib\site-packages\mamba_ssm\ops\triton\layernorm.py", line 478, in rms_norm_fn
    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
  File "D:\Anaconda\envs\xxx\lib\site-packages\torch\autograd\function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "D:\Anaconda\envs\xxx\lib\site-packages\mamba_ssm\ops\triton\layernorm.py", line 411, in forward
    y, mean, rstd, residual_out = _layer_norm_fwd(
  File "D:\Anaconda\envs\xxx\lib\site-packages\mamba_ssm\ops\triton\layernorm.py", line 155, in _layer_norm_fwd
    _layer_norm_fwd_1pass_kernel[(M,)](
  File "D:\Anaconda\envs\xxx\lib\site-packages\triton\runtime\jit.py", line 106, in launcher
    return self.run(*args, grid=grid, **kwargs)
  File "D:\Anaconda\envs\xxx\lib\site-packages\triton\runtime\autotuner.py", line 73, in run
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "D:\Anaconda\envs\xxx\lib\site-packages\triton\runtime\autotuner.py", line 73, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "D:\Anaconda\envs\xxx\lib\site-packages\triton\runtime\autotuner.py", line 63, in _bench
    return do_bench(kernel_call)
  File "D:\Anaconda\envs\xxx\lib\site-packages\triton\testing.py", line 136, in do_bench
    fn()
  File "D:\Anaconda\envs\xxx\lib\site-packages\triton\runtime\autotuner.py", line 62, in kernel_call
    self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
  File "<string>", line 41, in _layer_norm_fwd_1pass_kernel
  File "D:\Anaconda\envs\xxx\lib\site-packages\triton\compiler.py", line 1230, in compile
    so_cache_manager = CacheManager(so_cache_key)
  File "D:\Anaconda\envs\xxx\lib\site-packages\triton\compiler.py", line 1102, in __init__
    self.cache_dir = os.environ.get('TRITON_CACHE_DIR', default_cache_dir())
  File "D:\Anaconda\envs\xxx\lib\site-packages\triton\compiler.py", line 1093, in default_cache_dir
    return os.path.join(os.environ["HOME"], ".triton", "cache")
  File "D:\Anaconda\envs\xxx\lib\os.py", line 680, in __getitem__
    raise KeyError(key) from None
KeyError: 'HOME'

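On Windows you also need to edit D:\Anaconda\envs\xxx\lib\site-packages\mamba_ssm\ops\triton\layernorm.py under the mamba installation path, so that the norm functions use the pure-PyTorch reference implementations (layer_norm_ref, rms_norm_ref) instead of launching the Triton kernel. Specifically, change the following code in layernorm.py: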

def layer_norm_fn(
    x,
    weight,
    bias,
    residual=None,
    eps=1e-6,
    prenorm=False,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, is_rms_norm)


def rms_norm_fn(x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False, eps=1e-6):
    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)

to:

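def layer_norm_fn(
    x,
    weight,
    bias,
    residual=None,
    eps=1e-6,
    prenorm=False,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    # Reference path: avoids compiling the Triton kernel (and its HOME-based cache dir).
    # Dispatch on is_rms_norm so callers requesting RMSNorm through this function still get it.
    ref_fn = rms_norm_ref if is_rms_norm else layer_norm_ref
    return ref_fn(x, weight, bias, residual, eps, prenorm, residual_in_fp32)


def rms_norm_fn(x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False, eps=1e-6):
    # Same reference path for RMSNorm.
    return rms_norm_ref(x, weight, bias, residual, eps, prenorm, residual_in_fp32)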
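Note from the traceback that triton/compiler.py calls os.environ.get('TRITON_CACHE_DIR', default_cache_dir()). Python evaluates the default argument before .get() runs, so default_cache_dir() reads HOME even when TRITON_CACHE_DIR is set; Windows usually defines USERPROFILE instead. The stdlib sketch below mimics those two traceback lines with a plain dict (it is not Triton's code), and ensure_home is a hypothetical alternative workaround: it avoids the KeyError, but whether Triton kernels then compile on a given Windows setup is not verified here, which is why the post switches to the reference implementations.

```python
import os
from typing import MutableMapping


def default_cache_dir(env: MutableMapping[str, str]) -> str:
    # Mirrors the traceback line: os.path.join(os.environ["HOME"], ".triton", "cache")
    return os.path.join(env["HOME"], ".triton", "cache")


def cache_dir(env: MutableMapping[str, str]) -> str:
    # Mirrors: os.environ.get('TRITON_CACHE_DIR', default_cache_dir())
    # The default argument is evaluated first, so HOME is read unconditionally.
    return env.get("TRITON_CACHE_DIR", default_cache_dir(env))


def ensure_home(env: MutableMapping[str, str]) -> MutableMapping[str, str]:
    # Hypothetical workaround: fall back to USERPROFILE when HOME is missing.
    if "HOME" not in env and "USERPROFILE" in env:
        env["HOME"] = env["USERPROFILE"]
    return env


if __name__ == "__main__":
    windows_like = {"USERPROFILE": r"C:\Users\me", "TRITON_CACHE_DIR": r"D:\triton"}
    try:
        cache_dir(windows_like)
    except KeyError as err:
        print("KeyError:", err)  # raised even though TRITON_CACHE_DIR is set
    print(cache_dir(ensure_home(windows_like)))  # D:\triton
```

In a real run this would mean calling ensure_home(os.environ) at the top of the training script, before the first Triton kernel launches.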

3. Summary: configuring the Vim environment on Windows

Before installing causal_conv1d and mamba_ssm from the Vim source with setup.py, modify:

  1. Vim/mamba-1p1p1/mamba_ssm/ops/triton/layernorm.py
  2. Vim/mamba-1p1p1/mamba_ssm/ops/selective_scan_interface.py
  3. Vim/causal-conv1d/causal_conv1d/causal_conv1d_interface.py

Alternatively, after installing causal_conv1d and mamba_ssm with setup.py, modify:

  1. xxx\Anaconda\envs\xxx\Lib\site-packages\mamba_ssm\ops\triton\layernorm.py
  2. xxx\Anaconda\envs\xxx\Lib\site-packages\mamba_ssm\ops\selective_scan_interface.py
  3. xxx\Anaconda\envs\xxx\Lib\site-packages\causal_conv1d\causal_conv1d_interface.py

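If you configured Mamba via Solution (Win) described earlier and then run Vim, in addition to modifying these three source files you also need to replace xxx\Anaconda\envs\xxx\Lib\site-packages\mamba_ssm\modules\mamba_simple.py with the mamba_simple.py file from the Vim source.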
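The post-install route above amounts to copying a handful of files from the Vim repository over the installed packages. A minimal sketch with shutil, assuming the relative paths listed above (the location of mamba_simple.py inside the Vim repo is an assumption) and keeping a .bak copy of each original so the change can be reverted; the helper name and placeholder paths are hypothetical.

```python
import shutil
from pathlib import Path
from typing import Dict, List

# Vim-repo path -> installed path under site-packages, read off the lists above.
REPLACEMENTS: Dict[str, str] = {
    # Only meaningful after applying the layernorm.py edit to the Vim copy.
    "mamba-1p1p1/mamba_ssm/ops/triton/layernorm.py": "mamba_ssm/ops/triton/layernorm.py",
    "mamba-1p1p1/mamba_ssm/ops/selective_scan_interface.py": "mamba_ssm/ops/selective_scan_interface.py",
    "mamba-1p1p1/mamba_ssm/modules/mamba_simple.py": "mamba_ssm/modules/mamba_simple.py",
    "causal-conv1d/causal_conv1d/causal_conv1d_interface.py": "causal_conv1d/causal_conv1d_interface.py",
}


def overwrite_with_backup(vim_root: Path, site_packages: Path,
                          mapping: Dict[str, str] = REPLACEMENTS) -> List[Path]:
    """Copy Vim's files over the installed ones, keeping one .bak of each original."""
    pairs = [(vim_root / s, site_packages / d) for s, d in mapping.items()]
    # Validate everything first so a missing file does not leave a half-patched install.
    for src, dst in pairs:
        if not src.is_file():
            raise FileNotFoundError(f"not found in the Vim repo: {src}")
        if not dst.is_file():
            raise FileNotFoundError(f"not found in site-packages (package not installed?): {dst}")
    replaced = []
    for src, dst in pairs:
        backup = dst.with_name(dst.name + ".bak")
        if not backup.exists():  # never overwrite the first, stock backup
            shutil.copy2(dst, backup)
        shutil.copy2(src, dst)
        replaced.append(dst)
    return replaced


if __name__ == "__main__":
    # Placeholder paths: adjust to your clone of Vim and your conda env.
    vim = Path(r"D:\code\Vim")
    site = Path(r"D:\Anaconda\envs\xxx\Lib\site-packages")
    if vim.is_dir() and site.is_dir():
        for path in overwrite_with_backup(vim, site):
            print("replaced", path)
```

Restoring the stock install is then a matter of copying each .bak file back over its original.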

20240715 更新

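  1. Experiments show that on Windows both Mamba and Vmamba now compile normally, without bypassing selective_scan_cuda; for the installation steps, see the latest post linked in the navigation at the top. For Vim, you can still install Mamba first and then overwrite the source files in the virtual environment.
  2. As a reader pointed out, every download now costs 5 points, so downloading from CSDN is not encouraged. The correct installation steps, fixes for likely errors, and references are all public in this blog series, in keeping with the open-source spirit; please try it hands-on yourself.
  3. If nothing else works, for installation problems (only after reading the series carefully without finding a fix), resources, or idea collaboration, contact the WeChat ID on my homepage. There are too many private messages and comments, and they are rate-limited, so replies are not guaranteed.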

  1. https://github.com/bowang-lab/U-Mamba

  2. https://github.com/hustvl/Vim
