ubuntu安装JAX GPU版本
首先cuda是可以向下兼容的,因此,在ubuntu22.0.4上我的cuda为12.0,cudnn为8.8;但依然可以安装jax==0.4.7 jaxlib==0.4.7+cuda11.cudnn82
的版本。
0 安装前提
已经安装好cuda和对应的cudnn,以及anaconda。
从NVIDIA GPU官网中,我们可以看到当前支持的jax所对应的cuda和cudnn版本:
# 创建一个新的conda虚拟环境
# 在终端输入命令
conda create -n jaxEnv python=x.x
注意自己安装的python版本,创建环境指令后面的版本号要与以安装的python版本一致
# 查询创建的conda环境
conda info -e
# 激活自己刚创建的conda环境
conda activate jaxEnv
1 安装JAX GPU版本
# 可以使用上图中官方的安装命令
pip install --upgrade pip
# CUDA 12 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# CUDA 11 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# 或者
pip install --upgrade pip
# Installs the wheel compatible with CUDA 12 and cuDNN 8.9 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Installs the wheel compatible with CUDA 11 and cuDNN 8.6 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
或者可以参考google.storage,自己查询符合自己安装条件的版本
# 我使用的是如下命令安装成功
pip install --upgrade jax==0.4.7 jaxlib==0.4.7+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
如果已经有安装的jax,切记先卸载再安装对应版本即可
# 卸载jax和jaxlib的命令
pip uninstall jax jaxlib
3 测试
测试如下文件在所创建的jaxEnv
环境中是否能够得到如下类似结果
# TestJAX.py
import jax.numpy as np
from jax import random
import time
rng = random.PRNGKey(0)
x = random.uniform(rng, [5000, 5000])
st = time.time()
try:
y = np.matmul(x, x)
except Exception:
print("error")
print(time.time() - st)
print(y)
运行结果(没有无法识别GPU/TPU等信息即配置成功):
参考文章:
JAX: 库安装和GPU使用,解决不能识别gpu问题
ubuntu 安装 jax jaxlib cpu 和 gpu 版本 以及 tensorflow tensorRT的安装_如何安装jax_Eloudy的博客-程序员宅基地