mamba环境安装记录

折腾了一下午,最终发现还是离线安装最靠谱。以下是我这里可以正常运行的环境版本:
 

PyTorch version: 2.1.1
CUDA version: 11.8
causal-conv1d version: 1.2.0.post2
mamba-ssm version: 2.1.0
CUDA toolkit version: Build cuda_11.8.r11.8/compiler.31833905_0

测试代码如下:

import torch
import torch.nn.functional as F
from causal_conv1d import causal_conv1d_fn
from mamba_ssm import Mamba, Mamba2

# Pure-PyTorch causal depthwise 1-D convolution.
# NOTE: this definition rebinds the name `causal_conv1d_fn`, so it shadows the
# function of the same name imported from the `causal_conv1d` package.
def causal_conv1d_fn(x, weight, bias=None, activation=None):
    """Apply a causal depthwise 1-D convolution with an optional SiLU activation.

    Args:
        x: input tensor of shape (batch, dim, seqlen).
        weight: per-channel filter of shape (dim, width).
        bias: optional tensor of shape (dim,).
        activation: None for no activation, or "silu" / "swish" (both apply
            ``out * sigmoid(out)``).

    Returns:
        Tensor of shape (batch, dim, seqlen).

    Raises:
        NotImplementedError: if ``activation`` is not None, "silu" or "swish".
    """
    # Reject unsupported activations instead of silently skipping them.
    if activation not in (None, "silu", "swish"):
        raise NotImplementedError("activation must be None, silu, or swish")

    dim, width = weight.shape
    seqlen = x.shape[2]
    # Padding with width - 1 zeros and keeping only the first `seqlen` outputs
    # makes output position t depend on inputs at positions <= t only.
    # groups=dim gives every channel its own filter (depthwise convolution).
    output = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)[..., :seqlen]
    if activation is not None:
        output = output * torch.sigmoid(output)
    return output

# Smoke-test causal_conv1d_fn with and without an activation.
def test_causal_conv1d_fn():
    """Run causal_conv1d_fn on random CPU inputs and print each result.

    Three calls are made on the same inputs: no activation, "silu", and
    "swish". Nothing is asserted; the outputs are only printed.
    """
    n_batch, channels, steps, kernel = 2, 3, 5, 3

    # Random inputs, drawn in the order x, weight, bias.
    x = torch.randn(n_batch, channels, steps)
    weight = torch.randn(channels, kernel)
    bias = torch.randn(channels)

    # (heading to print, extra keyword arguments for the call)
    cases = (
        ("Output without activation:", {}),
        ("Output with SILU activation:", {"activation": "silu"}),
        ("Output with Swish activation:", {"activation": "swish"}),
    )
    for heading, extra in cases:
        result = causal_conv1d_fn(x, weight, bias, **extra)
        print(heading)
        print(result)

# Smoke-test a Mamba block on the GPU.
def test_mamba_model():
    """Build a Mamba block on "cuda", run one forward pass and print the output.

    The output is asserted to have the same shape as the input. A TypeError
    raised inside the try block is caught and printed instead of propagating;
    any other exception (including AssertionError) propagates.
    """
    device = "cuda"
    batch, length, dim = 2, 64, 16
    x = torch.randn(batch, length, dim).to(device)

    mamba_config = {
        "d_model": dim,  # Model dimension d_model
        "d_state": 16,   # SSM state expansion factor
        "d_conv": 4,     # Local convolution width
        "expand": 2,     # Block expansion factor
    }
    model = Mamba(**mamba_config).to(device)

    try:
        y = model(x)
        assert y.shape == x.shape
        print("Mamba model output shape:", y.shape)
        print(y)
    except TypeError as err:
        print(f"TypeError: {err}")

# Entry point: run both smoke tests in order.
def main():
    """Run the causal-conv1d test, then the Mamba test, printing a heading before each."""
    stages = (
        ("Testing causal_conv1d_fn:", test_causal_conv1d_fn),
        ("\nTesting Mamba model:", test_mamba_model),
    )
    for heading, run in stages:
        print(heading)
        run()


if __name__ == "__main__":
    main()

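如果以上两个测试都能正常输出且没有报错(Mamba 测试若打印出 `TypeError: ...` 则说明运行失败),就代表安装成功。注意:代码中重新定义的 causal_conv1d_fn 是纯 PyTorch 实现,会覆盖从 causal_conv1d 包导入的同名函数,因此第一个测试并不会调用已安装的 CUDA 扩展。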

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值