Ubuntu下jax安装与使用

1 篇文章 0 订阅

目录

安装说明

pip安装

conda安装

参考网址


注:该项目目前仍然没有官方的Windows支持,需要自己编译。

安装说明

该库安装时分为两部分:

  1. jaxlib,该库平台相关,目前没有官方的编译
  2. jax,该库依赖jaxlib,平台无关,可以直接安装。

找到目前一个还活跃的jaxlib非官方编译服务:

https://github.com/cloudhan/jax-windows-builder

pip安装

要安装仅 CPU 版本的 JAX,这可能对在笔记本电脑上进行本地开发很有用,您可以运行

pip install --upgrade pip
pip install --upgrade "jax[cpu]"

在 Linux 上,通常需要先更新pip到支持 manylinux2014轮子的版本。 这些pip安装不适用于 Windows,并且可能会静默失败;见 上文

如果要安装同时支持 CPU 和 NVidia GPU 的 JAX,则必须首先安装CUDA和 CuDNN(如果尚未安装)。与其他一些流行的深度学习系统不同,JAX 没有将 CUDA 或 CuDNN 捆绑为pip 软件包的一部分。

JAX仅为 Linux提供预构建的 CUDA 兼容轮子,带有 CUDA 11.1 或更高版本,以及 CuDNN 8.0.5 或更高版本。操作系统、CUDA 和 CuDNN 的其他组合是可能的,但需要从源代码构建

  • 需要CUDA 11.1 或更新版本。
  • 预建轮子支持的 cuDNN 版本是:
    • cuDNN 8.2 或更高版本。如果您的 cuDNN 安装足够新,我们建议使用 cuDNN 8.2 轮,因为它支持附加功能。
    • cuDNN 8.0.5 或更高版本。
  • 必须使用至少与您的 CUDA 工具包的相应驱动程序版本一样新的 NVidia 驱动程序版本。例如,如果您安装了 CUDA 11.4 update 4,则在 Linux 上必须使用 NVidia 驱动程序 470.82.01 或更新版本。这是一个严格的要求,因为 JAX 依赖于 JIT 编译代码;较旧的驱动程序可能会导致故障。
    • 如果您需要将较新的 CUDA 工具包与较旧的驱动程序一起使用,例如在无法轻松更新 NVidia 驱动程序的集群上,您可以使用 NVidia 为此目的提供的CUDA 前向兼容性包。

接下来,运行

pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

这些pip安装不适用于 Windows,并且可能会静默失败;

jaxlib 版本必须与您要使用的现有 CUDA 安装的版本相对应。您可以为 jaxlib 显式指定特定的 CUDA 和 CuDNN 版本:

pip install --upgrade pip

# Installs the wheel compatible with Cuda >= 11.4 and cudnn >= 8.2
pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Installs the wheel compatible with Cuda >= 11.1 and cudnn >= 8.0.5
pip install "jax[cuda11_cudnn805]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

具体版本
pip install --upgrade jax==0.3.15 jaxlib==0.3.15+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

conda安装

有一个社区支持的 Conda 构建jax。要安装 using conda,只需运行

conda install jax -c conda-forge

要在具有 NVidia GPU 的机器上安装,请运行

conda install jax cuda-nvcc -c conda-forge -c nvidia

请注意cudatoolkitDistributed by conda-forgeis missing ptxas,这是 JAX 要求的。因此,您必须cuda-nvcc从频道安装软件包nvidia,或者在您的机器上单独安装 CUDA,以便ptxas 在您的路径中。上面的频道顺序很重要(conda-forge之前 nvidia)。我们正在努力简化这一点。

如果您想覆盖 JAX 使用的 CUDA 版本,或者在没有 GPU 的机器上安装 CUDA 构建,请按照 网站提示和技巧 部分中的说明进行操作conda-forge

参考网址

https://github.com/google/jax

  • 0
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
您可以在Ubuntu安装JAX,具体步骤如下: 1. 首先,确保您的操作系统已经安装了Python和pip。 2. 如果您想安装支持CUDA 12的JAX,您可以运行以下命令: ``` pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html ``` 这将安装JAX和与CUDA 12兼容的扩展。 3. 如果您想安装支持CUDA 11的JAX,您可以运行以下命令: ``` pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html ``` 这将安装JAX和与CUDA 11兼容的扩展。 4. 如果您只需要安装CPU版本的JAX,您可以运行以下命令: ``` pip install --upgrade pip pip install --upgrade "jax<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *2* [Ubuntu22.04安装Whisper-jax](https://blog.csdn.net/qq_43907505/article/details/130866691)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v92^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"] - *3* [Ubuntujax安装使用](https://blog.csdn.net/zaf0516/article/details/126390534)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v92^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值