GitHub Link:
https://github.com/google/jax/issues/5723
主要原因是cuda版本和gpu的driver版本不匹配。导致无法并行计算或者编译。解决方法有两种:
- 跳过并行编程
XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1
- 安装合适的Cuda版本,查询链接:
https://docs.nvidia.com/deploy/cuda-compatibility/index.html
简单的来说,
CUDA Toolkit Linux x86_64 Minimum Required Driver Version Windows Minimum Required Driver Version
CUDA 12.x >=525.60.13 >=527.41
CUDA 11.x >= 450.80.02* >=452.39*
CUDA 10.2 >= 440.33 >=441.22
CUDA 10.1 >= 418.39 >=418.96
CUDA 10.0 >= 410.48 >=411.31
因此请检查对不对。当然也不一定,比如是450开头的driver还是用cuda10比较靠谱。