目录
注:该项目目前仍然没有官方的Windows支持,需要自己编译。
安装说明
该库安装时分为两部分:
- jaxlib,该库平台相关,目前没有官方的编译
- jax,该库依赖jaxlib,平台无关,可以直接安装。
找到目前一个还活跃的jaxlib非官方编译服务:
https://github.com/cloudhan/jax-windows-builder
pip安装
要安装仅 CPU 版本的 JAX,这可能对在笔记本电脑上进行本地开发很有用,您可以运行
pip install --upgrade pip
pip install --upgrade "jax[cpu]"
在 Linux 上,通常需要先更新pip
到支持 manylinux2014
轮子的版本。 这些pip
安装不适用于 Windows,并且可能会静默失败;见 上文。
如果要安装同时支持 CPU 和 NVidia GPU 的 JAX,则必须首先安装CUDA和 CuDNN(如果尚未安装)。与其他一些流行的深度学习系统不同,JAX 没有将 CUDA 或 CuDNN 捆绑为pip
软件包的一部分。
JAX仅为 Linux提供预构建的 CUDA 兼容轮子,带有 CUDA 11.1 或更高版本,以及 CuDNN 8.0.5 或更高版本。操作系统、CUDA 和 CuDNN 的其他组合是可能的,但需要从源代码构建。
- 需要CUDA 11.1 或更新版本。
- 预建轮子支持的 cuDNN 版本是:
- cuDNN 8.2 或更高版本。如果您的 cuDNN 安装足够新,我们建议使用 cuDNN 8.2 轮,因为它支持附加功能。
- cuDNN 8.0.5 或更高版本。
- 您必须使用至少与您的 CUDA 工具包的相应驱动程序版本一样新的 NVidia 驱动程序版本。例如,如果您安装了 CUDA 11.4 update 4,则在 Linux 上必须使用 NVidia 驱动程序 470.82.01 或更新版本。这是一个严格的要求,因为 JAX 依赖于 JIT 编译代码;较旧的驱动程序可能会导致故障。
- 如果您需要将较新的 CUDA 工具包与较旧的驱动程序一起使用,例如在无法轻松更新 NVidia 驱动程序的集群上,您可以使用 NVidia 为此目的提供的CUDA 前向兼容性包。
接下来,运行
pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
这些pip
安装不适用于 Windows,并且可能会静默失败;
jaxlib 版本必须与您要使用的现有 CUDA 安装的版本相对应。您可以为 jaxlib 显式指定特定的 CUDA 和 CuDNN 版本:
pip install --upgrade pip
# Installs the wheel compatible with Cuda >= 11.4 and cudnn >= 8.2
pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Installs the wheel compatible with Cuda >= 11.1 and cudnn >= 8.0.5
pip install "jax[cuda11_cudnn805]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
具体版本
pip install --upgrade jax==0.3.15 jaxlib==0.3.15+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
conda安装
有一个社区支持的 Conda 构建jax
。要安装 using conda
,只需运行
conda install jax -c conda-forge
要在具有 NVidia GPU 的机器上安装,请运行
conda install jax cuda-nvcc -c conda-forge -c nvidia
请注意cudatoolkit
Distributed by conda-forge
is missing ptxas
,这是 JAX 要求的。因此,您必须cuda-nvcc
从频道安装软件包nvidia
,或者在您的机器上单独安装 CUDA,以便ptxas
在您的路径中。上面的频道顺序很重要(conda-forge
之前 nvidia
)。我们正在努力简化这一点。
如果您想覆盖 JAX 使用的 CUDA 版本,或者在没有 GPU 的机器上安装 CUDA 构建,请按照 网站提示和技巧 部分中的说明进行操作conda-forge
。