问题:
使用jax时无法检测到本地GPU:
import jax
print(jax.devices())
>>>CUDA backend failed to initialize: Unable to load cuDNN. Is it installed? (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[CpuDevice(id=0)]
解决方案:
我跟随的解决方案是下面这个链接,但是和他的问题不一样:
JAX: 库安装和GPU使用,解决不能识别gpu问题
跟随他的步骤时发现jax又好了:
import torch,jax
print(jax.devices())
>>>[cuda(id=0), cuda(id=1)]
print(jax.local_devices())
>>>[cuda(id=0), cuda(id=1)]
那就只能是因为torch了。用编译器查询一下哪里有
import jax
改成
import torch, jax
大概因为当初环境里的cuda相关的cuda-enabled libs是随着torch一起装的,导致使用时需要把torch引入一下才能启动。平时有关模型的代码,torch都是默认import的,这里没用到就出问题了。