ubuntu安装JAX GPU版本


首先cuda是可以向下兼容的,因此,在ubuntu22.0.4上我的cuda为12.0,cudnn为8.8;但依然可以安装jax==0.4.7 jaxlib==0.4.7+cuda11.cudnn82的版本。

0 安装前提

已经安装好cuda和对应的cudnn,以及anaconda。

NVIDIA GPU官网中,我们可以看到当前支持的jax所对应的cuda和cudnn版本:
在这里插入图片描述

# 创建一个新的conda虚拟环境
# 在终端输入命令
conda create -n jaxEnv python=x.x

注意自己安装的python版本,创建环境指令后面的版本号要与以安装的python版本一致

# 查询创建的conda环境
conda info -e
# 激活自己刚创建的conda环境
conda activate jaxEnv

1 安装JAX GPU版本

# 可以使用上图中官方的安装命令
pip install --upgrade pip

# CUDA 12 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# CUDA 11 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# 或者
pip install --upgrade pip

# Installs the wheel compatible with CUDA 12 and cuDNN 8.9 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Installs the wheel compatible with CUDA 11 and cuDNN 8.6 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

或者可以参考google.storage,自己查询符合自己安装条件的版本
在这里插入图片描述

# 我使用的是如下命令安装成功
pip install --upgrade jax==0.4.7 jaxlib==0.4.7+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

如果已经有安装的jax,切记先卸载再安装对应版本即可

# 卸载jax和jaxlib的命令
pip uninstall jax jaxlib

3 测试

测试如下文件在所创建的jaxEnv环境中是否能够得到如下类似结果

# TestJAX.py
import jax.numpy as np
from jax import random
import time

rng = random.PRNGKey(0)
x = random.uniform(rng, [5000, 5000])
st = time.time()
try:
    y = np.matmul(x, x)
except Exception:
    print("error")
print(time.time() - st)
print(y)

运行结果(没有无法识别GPU/TPU等信息即配置成功):
在这里插入图片描述


参考文章:
JAX: 库安装和GPU使用,解决不能识别gpu问题

安装支持CUDA 12的pytorch教程

ubuntu 安装 jax jaxlib cpu 和 gpu 版本 以及 tensorflow tensorRT的安装_如何安装jax_Eloudy的博客-程序员宅基地

  • 2
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
您可以在Ubuntu安装JAX,具体步骤如下: 1. 首先,确保您的操作系统已经安装了Python和pip。 2. 如果您想安装支持CUDA 12的JAX,您可以运行以下命令: ``` pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html ``` 这将安装JAX和与CUDA 12兼容的扩展。 3. 如果您想安装支持CUDA 11的JAX,您可以运行以下命令: ``` pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html ``` 这将安装JAX和与CUDA 11兼容的扩展。 4. 如果您只需要安装CPU版本JAX,您可以运行以下命令: ``` pip install --upgrade pip pip install --upgrade "jax<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *2* [Ubuntu22.04安装Whisper-jax](https://blog.csdn.net/qq_43907505/article/details/130866691)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v92^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"] - *3* [Ubuntujax安装与使用](https://blog.csdn.net/zaf0516/article/details/126390534)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v92^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值