微信公众号:leetcode_algos_life,代码随想随记
小红书:412408155
CSDN:https://blog.csdn.net/woai8339?type=blog ,代码随想随记
GitHub: https://github.com/riverind
抖音【暂未开始,计划开始】:tian72530,代码随想随记
知乎【暂未开始,计划开始】:happy001
问题描述
在安装jax中,安装命令
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
出现报错,报错信息如下:
CUDA backend failed to initialize: Found cuBLAS version 120205, but JAX was built against version 120304, which is newer. The copy of cuBLAS that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
在gpu机器上测试,却是显示cpu。
解决方案
问题原因是版本问题,解决方案是版本降级。
pip install --upgrade "jax[cuda12_pip]==0.4.20" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
测试
测试代码如下:
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
再次运行,显示GPU,搞定。