Ubuntu18.04,有GPU,jax安装后显示错误如图所示,无法识别GPU,使用升级后并不能解决
Cuda >= 11.8 and cudnn >= 8.6,采用如下方法重新安装 pip install "jax[cuda11_cudnn86]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
如果Cuda >= 11.4 and cudnn >= 8.2
pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
查看当前系统cudnn版本方法:
vim /usr/local/cuda-11.6/include/cudnn_version.h
当前cudnn版本为8.4