环境:Python3.10
jax版本要求:
jax>=0.4.16.0,<=0.4.25
jaxlib>=0.4.16.0,<=0.4.25
cuda和cudnn版本:12.4+8.9.7
从4.17到4.25试了个遍都不好使,参考博客【Jax报错】CUDA backend failed to initialize: Unable to load cuPTI也不管用
之前装过一次jax可以用,对比了一下pip包,发现这次主要是由于我的cuda符合要求,在安装jax时候就没装NVIDIA 的 CUDA兼容包,就是下面这些包:
nvidia-cublas-cu12 12.1.3.1 pypi_0 pypi
nvidia-cuda-cupti-cu12 12.1<