JAX安装过程中遇到的各种坑

什么是JAX

JAX是谷歌于2018年推出的一个计算框架。目的是为了加速各类科学计算。
JAX官方网站:JAX github官网
JAX的潜力是巨大的,JAX在某些情况下的计算速度上可以说是远超现有的已经非常成熟的numpy,如下图所示:
Alt 在这里插入图片描述
图片来源:图片出处
并且由于JAX对于TPU具有良好的支持,随着TPU在深度学习领域的地位越来越重要,在未来JAX有可能会成为主流的深度学习计算框架。

安装JAX

下面介绍如何安装JAX,以及我在安装JAX过程中遇到的坑。(注意,目前JAX还不支持在Windows系统下使用GPU,本文将聚焦如何在Linux系统下安装JAX的GPU版本)

安装JAX遇到的坑

根据官方教程下的指示安装JAX。
由于JAX的版本与CUDA,cudnn的版本高度相关,因此我们在安装JAX之前需要先确定我们的CUDA版本和cudnn版本。(这里默认大家已经安装好了CUDA和cudnn,本篇博客不再赘述如何安装CUDA和cudnn)

坑1:查看CUDA和cudnn版本

查看CUDA版本

网上大部分博客关于如何在Linux系统下查看CUDA版本给的命令是

nvidia-smi

然后系统会在输出的右上角显示CUDA版本,如下图所示:
在这里插入图片描述
因此我们此时可能会以为我们的CUDA版本是12.0,但其实不是,具体原因可以参考这篇博客
按照上述博客的方法我们可以输入以下命令来查看实际的CUDA版本

nvcc -version

得到我们的CUDA版本是11.3

查看cudnn版本

在路径‘“/usr/local/cuda/targets/x86_64-linux/include”下的cudnn_version.h文件中找到类似
”CUDNN_MAJOR 8
CUDNN_MINOR 2”
在这里插入图片描述

的内容。如本例中所示,就说明我的cudnn版本是8.2

坑2:安装JAX的版本问题

现在我们已经知道了我们的CUDA是11.3,cudnn是8.2(这只是我环境中的示例,具体CUDA和cudnn一定要以你们自己的的环境中为准)
然后我们可以在官方教程中找到安装方式,如下图所示:
在这里插入图片描述
由于我们之前看到的CUDA是11.3,cudnn是8.2,所以我们自然会选择第二个命令对应cuda11。

pip install --upgrade pip

# CUDA 12 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# CUDA 11 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

然而其实这里有一个坑,就是你实际pip安装的JAX版本所对应的cudnn或许并不是和你的cudnn版本一致,举个例子,在下图我运行了这个代码,然而实际安装的包是对应的cudnn8.6的版本。
在这里插入图片描述
那我们应该如何安装和我们的CUDA和cudnn版本对应的JAX呢?
其实很简单,使用如下命令可以获得一个报错信息,报错信息给出了可用的所有版本的信息,在这里我们选择一个“0.4.6+cuda11.cudnn82”的版本安装即可。

pip install --upgrade jax==0.4.6 jaxlib==-0.4.6+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

在这里插入图片描述
然后我们输入以下命令安装即可,注意jax和jaxlib的版本需要一致,否则可能会报错。

 pip install --upgrade jax==0.4.6 jaxlib==0.4.6+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

坑3:关于ptxas的报错,以及更新PATH路径

现在我们已经安装了与我们机器CUDA和cudnn版本一致的jax和jaxlib,然后当我们信心满满地运行我们的代码的时候却发现了以下报错:

W external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc:111] *** WARNING *** You are using ptxas 10.0.145, which is older than 11.8. ptxas before 11.1 is known to miscompile XLA code, leading to incorrect results or invalid-address errors.

You may not need to update to CUDA 11.8; cherry-picking the ptxas binary is often sufficient.
2023-02-15 09:07:00.421161: W external/org_tensorflow/tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.cc:231] Falling back to the CUDA driver for PTX compilation; ptxas does not support CC 9.0
2023-02-15 09:07:00.421186: W external/org_tensorflow/tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.cc:234] Used ptxas at ptxas
2023-02-15 09:07:00.520708: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:538] failed to load PTX text as a module: CUDA_ERROR_INVALID_IMAGE: device kernel image is invalid
2023-02-15 09:07:00.520773: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:543] error log buffer (63 bytes): error   : Binary format for key='0', ident='' is not recognize
2023-02-15 09:07:00.520821: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2410] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.func.launch' failed: Failed to load PTX text as a module: CUDA_ERROR_INVALID_IMAGE: device kernel image is invalid.

出现以上的报错的原因在于CUDA或者cudnn版本过低,在第二个报错的第一行中指出”ptxas does not support CC 9.0“,我们需要根据这个图来决定CC9.0对应的CUDA版本是多少,图中给的是11.8和12.0-12.3。(图片出处
在这里插入图片描述

因此我们需要将CUDA版本升级到11.8。(CUDA版本升级cudnn一般也需要跟着一起升级,在这里我们安装的是CUDA11.8,cudnn8.6)

在安装完CUDA11.8和cudnn8.6之后我们需要用新的CUDA路径去更新PATH路径,这是因为jax在运行的时候会使用PATH路径来寻找ptxas,因此如果我们更新了CUDA和cudnn但没有更新PATH,对于jax来说它还是不知道新的ptxas在哪里
首先我们需要找到CUDA下的bin文件夹,确定里面有ptxas,如下图所示:
在这里插入图片描述

然后在shell窗口中使用如下命令,将新路径添加到PATH的开头。(因为系统在读取PATH路径是从前往后读取的,所以需要添加到最前面这样才不会被之前的路径给覆盖掉)

export PATH=/new/path:$PATH

其中“/new/path”是你想要添加的路径,在这里我的路径需要改为“/usr/local/cuda-11.8/bin“,具体需要什么路径以自己的实际系统中的为准。

export PATH=/usr/local/cuda-11.8/bin:$PATH

小结

第一次安装JAX,遇到了很多问题,来来回回折腾了两天。
总的来说就是需要在安装JAX的时候考虑到自己的CUDA和cudnn版本的问题,同时自己服务器上的CUDA和cudnn版本太低的问题。
之所以折腾了这么久其中有一部分原因是因为JAX这个项目目前还不算是很完善,官方的文档给的不够详细,还有一部分原因是自己的pip,cuda以及linux相关的知识掌握得不够牢固。
这篇博客既是回顾了一下自己安装jax时遇到的问题,踩过的坑以及解决的办法,同时也希望可以给在安装JAX时遇到困难的人一些启发吧。

  • 35
    点赞
  • 44
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值