折腾了一下午，最终发现还是离线安装最靠谱。以下是我可以正常运行的环境版本：
PyTorch version: 2.1.1
CUDA version: 11.8
causal-conv1d version: 1.2.0.post2
mamba-ssm version: 2.1.0
CUDA toolkit version: Build cuda_11.8.r11.8/compiler.31833905_0
测试代码如下：
import torch
import torch.nn.functional as F
from causal_conv1d import causal_conv1d_fn
from mamba_ssm import Mamba, Mamba2
# Define causal_conv1d_fn: a pure-PyTorch reference implementation.
# NOTE(review): this definition shadows the `causal_conv1d_fn` imported from
# the `causal_conv1d` package above, so later calls to `causal_conv1d_fn` in
# this script run this reference version, not the installed kernel.
def causal_conv1d_fn(x, weight, bias=None, activation=None):
    """Apply a depthwise causal 1-D convolution, optionally followed by SiLU.

    Each channel ``d`` of ``x`` is convolved with its own filter ``weight[d]``
    (``groups=dim``). The result is trimmed so that output position ``t`` only
    depends on input positions ``<= t``.

    Args:
        x: Input of shape (batch, dim, seqlen); the last axis is time.
        weight: Per-channel filters of shape (dim, width).
        bias: Optional per-channel bias of shape (dim,).
        activation: None, "silu" or "swish" ("swish" is the same function as
            SiLU: z * sigmoid(z)).

    Returns:
        Tensor with the same shape as ``x``.

    Raises:
        NotImplementedError: If ``activation`` is not None, "silu" or "swish".
    """
    # Reject unknown activations instead of silently returning the linear output.
    if activation not in (None, "silu", "swish"):
        raise NotImplementedError("activation must be None, silu, or swish")
    dim, width = weight.shape
    seqlen = x.shape[-1]
    # conv1d pads both ends by width - 1. Keeping only the first seqlen outputs
    # drops the right-padding results, which is what makes the convolution causal.
    out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
    out = out[..., :seqlen]
    # F.silu(z) computes z * sigmoid(z).
    return out if activation is None else F.silu(out)
# Test causal_conv1d_fn.
def test_causal_conv1d_fn():
    """Smoke-test the reference ``causal_conv1d_fn`` and, when possible, the installed kernel.

    Prints the reference output for no activation, "silu" and "swish", and
    checks that "silu" and "swish" agree. If CUDA is available, it also checks
    that the kernel shipped by the ``causal_conv1d`` package matches the
    reference. Without that comparison, only the package import would be
    exercised, because the module-level name refers to the reference version.
    """
    # Imported under an alias: the module-level name `causal_conv1d_fn` refers
    # to the pure-PyTorch reference defined in this file.
    from causal_conv1d import causal_conv1d_fn as kernel_causal_conv1d_fn

    batch_size, dim, seqlen, width = 2, 3, 5, 3
    torch.manual_seed(0)  # reproducible random inputs
    x = torch.randn(batch_size, dim, seqlen)
    weight = torch.randn(dim, width)
    bias = torch.randn(dim)

    cases = (
        (None, "without activation"),
        ("silu", "with SILU activation"),
        ("swish", "with Swish activation"),
    )
    outputs = {}
    for activation, label in cases:
        out = causal_conv1d_fn(x, weight, bias, activation=activation)
        assert out.shape == x.shape
        print(f"Output {label}:")
        print(out)
        outputs[activation] = out
    # "swish" and "silu" name the same function.
    assert torch.allclose(outputs["silu"], outputs["swish"])

    if not torch.cuda.is_available():
        print("CUDA not available; skipping comparison with the installed causal_conv1d kernel.")
        return
    for activation in (None, "silu"):
        got = kernel_causal_conv1d_fn(
            x.cuda(), weight.cuda(), bias.cuda(), activation=activation
        )
        assert torch.allclose(got.cpu(), outputs[activation], atol=1e-4, rtol=1e-3), (
            f"installed causal_conv1d kernel disagrees with reference (activation={activation})"
        )
    print("Installed causal_conv1d kernel matches the reference implementation.")
# Test the Mamba model.
def test_mamba_model():
    """Run one forward pass of a small ``Mamba`` block on CUDA and check the output shape.

    A ``TypeError`` from the forward pass is printed and then re-raised. It
    presumably indicates an API mismatch between the installed packages
    (TODO confirm), and swallowing it would make a broken install look like a
    successful run.
    """
    batch, length, dim = 2, 64, 16
    device = "cuda"
    x = torch.randn(batch, length, dim, device=device)
    model = Mamba(
        d_model=dim,  # Model dimension d_model
        d_state=16,  # SSM state expansion factor
        d_conv=4,  # Local convolution width
        expand=2,  # Block expansion factor
    ).to(device)
    try:
        y = model(x)
    except TypeError as e:
        print(f"TypeError: {e}")
        raise
    assert y.shape == x.shape
    print("Mamba model output shape:", y.shape)
    print(y)
# Run the test functions.
def main():
    """Run each smoke test in order, printing its section header first."""
    for header, run_test in (
        ("Testing causal_conv1d_fn:", test_causal_conv1d_fn),
        ("\nTesting Mamba model:", test_mamba_model),
    ):
        print(header)
        run_test()


# Only run the tests when executed as a script, not when imported.
if __name__ == "__main__":
    main()
运行后能正常输出且没有报错，即代表安装成功。