源码编译causal-conv1d
⚠️ 关键背景——为什么官方 wheel 用不了
症状 | 根因 |
---|---|
UserWarning: … sm_120 is not compatible with the current PyTorch installation 或者运行时报 no kernel image is available for execution on the device | 截至 2025-05-04,PyTorch 2.3.x 官方二进制只编到 sm_90;5090(Blackwell 消费级)是 Compute Capability 12.0 (= sm_120),版本判断直接挡下来了(GitHub, GitHub) |
即便你自己 export TORCH_CUDA_ARCH_LIST="12.0" ,pip install causal-conv1d 仍旧失败 | causal-conv1d 的 setup.py 写死了 “只允许 ≤ sm_90” 的白名单,导致编译前就被 assert 掉 |
结论: 想跑 5090,必须让 PyTorch 和 causal-conv1d 两层都认识 sm_120;最简单办法就是 全部源码编译一次。好在 CUDA 12.8+ 的 nvcc
已经内置 sm_120
开关(NVIDIA Developer Forums)。
0. 环境前提
-
驱动 / CUDA Toolkit ≥ 12.8
nvcc --version # release 12.8.xx 或 12.9.xx nvidia-smi | head -n1 # driver ≥ 555.xx
-
Python & gcc:建议 Python 3.10+/3.11,gcc >= 11(Blackwell 编译器路径里有些 C++20 flag)。
-
建议单独 conda/venv,别和旧 CUDA 共存。
1. 让 PyTorch 支持 sm_120
方案 A — 直接用 nightly cu129(最快)
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu129
python - <<'PY'
import torch, json, os
print("CUDA", torch.version.cuda, "arch list:", torch.cuda.get_arch_list())
PY
# 看到 'sm_120' 即 OK
2025-03-15 之后的 nightly 已合入 [Add SM120 support] PR,且默认编译 PTX/Fatbin 到 12.0。
如果你想保守一点,可以走方案 B 手工编译稳定分支。
方案 B — 源码编 PyTorch(保证可控)
git clone --recursive https://github.com/pytorch/pytorch
cd pytorch
git checkout v2.3.1 # 或 main 分支最新
export USE_CUDA=1
export TORCH_CUDA_ARCH_LIST="8.0;8.6;8.9;9.0;12.0"
export MAX_JOBS=$(nproc)
python setup.py bdist_wheel
pip install dist/torch-*.whl
编完再跑上一段 torch.cuda.get_arch_list()
做 sanity check。
2. Patch causal-conv1d
让它接受 sm_120
官方仓库白名单只到
'90'
,需要手动改一行。
git clone https://github.com/Dao-AILab/causal-conv1d.git
cd causal-conv1d
# ★ 只改这一行即可
sed -i 's/\"90\"/\"90\",\"120\"/' setup.py
(如果怕 sed 位置对不准,手动打开 setup.py
搜 VALID_ARCHES
就能看到列表。)
3. 编译安装 causal-conv1d
# 让 nvcc 只编 12.0,可加 +ptx 生成 PTX 备用
export TORCH_CUDA_ARCH_LIST="12.0;12.0+ptx"
# CMake 体系再双保险
export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=120"
# 一键编译 + 装 wheel
pip install --no-build-isolation --no-binary causal-conv1d -v .
编译时间比 sm_9X 长很多(Blackwell 编译器寄存器分配更激进,见 NV forum 讨论)——耐心等。成功后:
import torch, causal_conv1d, os
print(torch.cuda.get_device_name(), causal_conv1d.__version__)
4. 运行自检
from causal_conv1d import causal_conv1d_fn
x = torch.randn(2, 64, 512, device='cuda', dtype=torch.float16)
w = torch.randn(64, 4, device='cuda', dtype=torch.float16)
y = causal_conv1d_fn(x, w)
print(y.shape, y.is_cuda) # 期望 (2, 64, 512) True
5. 常见疑难杂症
现象 | 处理 |
---|---|
nvcc fatal: unrecognized option 'compute_120' | CUDA Toolkit < 12.8;升级 Toolkit |
sm_120 相关编译极慢 / 寄存器溢出 | 这是 NVCC for Blackwell 已知现象;可以先编 compute_120,code=compute_120 只留 PTX,让运行时 JIT,性能略低但能跑 |
运行还报 no kernel image… | 检查 torch.version.cuda 和 nvcc 是否同属 12.x,以及 torch.cuda.get_arch_list() 是否包含 'sm_120' |
Flash-Attention / bitsandbytes 等其它 CUDA 扩展也炸 | 按同样思路打开它们的 setup.py 白名单,或等官方合入 sm_120 PR (很多项目已有 issue 正在跟进)(GitHub) |
6. 性能与后续优化
-
Blackwell 的 shared-memory / register 行为跟 Hopper 不同;
causal-conv1d
目前 kernel 还是针对 Ampere/Hopper 的 schedule,能用但不一定最优。 -
如果后续官方 upstream 把 sm_120 加进来,记得把你手改的仓库换回官方 release,以免错过 kernel-level 优化。
-
也可以在本地开
nvprof
或Nsight Compute
对比 compute_120 PTX-JIT vs fatbin sm_120,据 NV forum 实测差距 ~10-15%,可按需求权衡。
TL;DR
# 5090 / sm_120 最短路径
pip install --pre torch torchvision torchaudio -i https://download.pytorch.org/whl/nightly/cu129
git clone https://github.com/Dao-AILab/causal-conv1d && cd causal-conv1d
sed -i 's/"90"/"90","120"/' setup.py
TORCH_CUDA_ARCH_LIST="12.0;12.0+ptx" pip install -v .
这样即可在 Linux 本地成功编译并运行 causal-conv1d
。有别的报错把日志贴出来,我们再逐条拆!
源码编译 mamba-ssm
下面给出 RTX 5090(Compute Capability 12.0 = sm_120) 环境下,从源码编译 mamba-ssm 的完整步骤。整体思路跟你刚才搞定 causal-conv1d
一样:
-
先准备能识别 sm_120 的 PyTorch;
-
手动打补丁,让
setup.py
把 sm_120 加进 nvcc 编译参数; -
强制本地编译,跳过旧 wheel;
-
运行快测。
0 前置:确保 PyTorch 已支持 sm_120
你刚才用的 nightly cu129 或自己源码编的 PyTorch 已经能 torch.cuda.get_arch_list()
➜ ['sm_120', …]
,就 OK 了。如果后来换过环境,先确认:
python - <<'PY'
import torch, subprocess, os
print('Torch:', torch.__version__, ' Built-for-CUDA:', torch.version.cuda)
print('nvcc :', subprocess.check_output(["nvcc","-V"], text=True).split()[-1])
print('arch :', torch.cuda.get_arch_list())
PY
1 克隆源码并切到最新版 tag
git clone https://github.com/state-spaces/mamba.git
cd mamba
git checkout v2.2.4 # 2024-12-06 最新正式版:contentReference[oaicite:0]{index=0}
说明:
mamba-ssm
PyPI 的 2.2.x wheel 还只编到 sm_90,因此必须自己编。
2 一键补丁:把 sm_120 写进 setup.py
setup.py
里有一段手动列出的 nvcc 架构白名单,最末尾停在 sm_90(GitHub)。直接用 sed
把 12.0 插进去即可:
# ① 在 compute_90 那行下面补一行 120
sed -i '/arch=compute_90/a\ cc_flag.append("-gencode"); cc_flag.append("arch=compute_120,code=sm_120")' setup.py
(想精简编译时间,可再把前面 53-87 这些老卡算力注释掉,只留 120。)
3 编译 & 安装
# 让 pip 不去找旧 wheel,强制本地编
export MAMBA_FORCE_BUILD=TRUE
# 缩小 nvcc PTX 范围,进一步加快
export TORCH_CUDA_ARCH_LIST="12.0;12.0+ptx"
# 可选:并行线程数
export MAX_JOBS=$(nproc)
# 真正开编
pip install --no-build-isolation -v .
-
如果你在国内、GitHub Release 下载慢,
MAMBA_FORCE_BUILD=TRUE
能避免无谓的 wheel 下载超时。 -
编译时间会比
causal-conv1d
长不少(大概 3-5 min / A100,5090 略慢);耐心等日志出现 “building wheel for mamba-ssm done”。
4 验证
python - <<'PY'
import torch, mamba_ssm
from mamba_ssm import Mamba
x = torch.randn(2, 1024, 256, device='cuda', dtype=torch.float16)
model = Mamba(d_model=256, d_state=16, d_conv=4, expand=2).cuda().half()
y = model(x)
print('ok', y.shape, 'arch', torch.cuda.get_device_properties(0).name)
PY
只要能跑出 ok torch.Size([2, 1024, 256]) arch NVIDIA RTX 5090
,并且 无 “no kernel image for sm_120” 就表示编译成功。
5 常见报错速查
日志片段 | 原因 | 处理 |
---|---|---|
cc_flag.append("arch=compute_120...") 仍未出现 | sed 没打中行 | 手动打开 setup.py 搜 compute_90 ,再把 120 行粘上 |
nvcc fatal: Unsupported gpu architecture 'compute_120' | CUDA Toolkit < 12.8 | sudo apt install cuda-12-9 或升级 runfile |
__nvJitLinkComplete_12_4… | Torch 用的 CUDA ≥ 12.3,但系统里 nvcc 还是 12.0-12.2 | 升级 Toolkit ≥ 12.3,或重装匹配的 PyTorch |
编译极慢 / OOM | Blackwell nvcc 寄存器溢出 | 先编 compute_120,code=compute_120 (只产 PTX),运行时 JIT,或用 -maxrregcount 调小寄存器 |
6 与其它 CUDA 扩展兼容小贴士
-
Flash-Attention-2、bitsandbytes 如遇同样 sm_120 限制,可以比照此法在各自
setup.py
增补 120。 -
mamba-ssm
自带的selective_scan_cuda
会随刚才补丁一起编好,不需额外操作。
最小可复现脚本(TL;DR)
# 假设已装 nightly/cu128 Torch,且 causal-conv1d 已编好
git clone https://github.com/state-spaces/mamba && cd mamba && git checkout v2.2.4
sed -i '/compute_90/a\ cc_flag.append("-gencode"); cc_flag.append("arch=compute_120,code=sm_120")' setup.py
MAMBA_FORCE_BUILD=TRUE TORCH_CUDA_ARCH_LIST="12.0;12.0+ptx" pip install -v .