windows安装jax和jaxlib的教程

本文你将解决3个问题:1、jaxlib没有安装的问题;2、python3.9以上(不可忽略)、cuda12.1(可忽略)以上配置要求不满足的问题;3、numpy版本太高的问题。

1、问题描述

当你直接pip install jax或者conda install jax后,执行以下代码检查是否错误:

import jax
print(jax.devices())  # 应输出类似 [gpu(id=0)]

总是会报错:ModuleNotFoundError: jax requires jaxlib to be installed. See https://github.com/google/jax#installation for installation instructions.

在这里插入图片描述

出现该问题的原因是没有安装jaxlib。jaxlib只支持python3.9以上版本,且需要手动安装(直接用pip install jaxlib会报错)

ERROR: Could not find a version that satisfies the requirement jaxlib (from versions: none)
ERROR: No matching distribution found for jaxlib

2、解决办法

下面有2种情况,按照你的Windows电脑是否需要cuda来选择对应的教程。

  • 情况1,你不用cuda,那么只需要执行以下2步:

1、在虚拟环境中,在python3.9及以上的版本安装jax库,如 pip install jax 或者conda install jax,可以指定版本,这些就和一般的安装库那样。
2、下载jaxlib的文件,并手动安装。在https://storage.googleapis.com/jax-releases/jax_releases.html 地址中,键盘快捷键"ctrl + F"搜索"win" 找到对应python版本的jaxlib文件,jaxlib的版本自行测试吧。将其下载在本地任意文件夹中,然后像一般安装那样,在你的虚拟环境中安装此文件。

在这里插入图片描述

  • 情况2,已经配置了一个cuda11(或者以下的版本;如果你是cuda12及以上的版本,同样按照下面第2个步骤执行),那么只需要执行以下2步:

1、先安装cuda12(12.1以上的版本,必要的操作,不能跳过;无需卸载之前的cuda版本,多个版本的cuda可以共存),具体教程见以下两个教程(如果链接失效,请到我的csdn主页查找同名教程):
a. cuda 安装两个版本 https://blog.csdn.net/AdamCY888/article/details/147516608
b. 驱动支持的最高CUDA版本与实际安装的Runtime版本 https://blog.csdn.net/AdamCY888/article/details/147516543


在这里插入图片描述


(截图来自jax教程:https://jax.net.cn/en/latest/installation.html#installation

2、上面步骤1确保你已经有一个12.1以上版本的cuda。

a. 下载jax:pip install -U "jax[cuda12]", 注意,引号不能省略,且建议不指定其jax版本。
b. 接下来同前面情况1的步骤2一样,下载jaxlibwhl文件。自行对应相应的版本。

在这里插入图片描述

3、测试jax对应jaxlib的版本

由于并没有找到jax对应jaxlib的版本,于是就安装一个最低版本的jaxlib 0.4.13,按照其报错提示,来得到满足的版本。正确的对应关系是:jax 0.4.21 对应的 jaxlib 0.4.19;如果安装的其它版本,也可以通过这个方法来解决。

RuntimeError: jaxlib is version 0.4.13, but this version of jax requires version >= 0.4.19.

在这里插入图片描述
于是,重新在 https://storage.googleapis.com/jax-releases/jax_releases.html 下载"jaxlib 0.4.19",并安装。

在这里插入图片描述

接下来进一步测试以下程序:

import jax.numpy as jnp
def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(5.0)
print(selu(x))

报错:

A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.2.5 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "d:\Anaconda\envs\jax_cuda12\lib\runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "d:\Anaconda\envs\jax_cuda12\lib\runpy.py", line 86, in _run_code
    exec(code, run_globals)
 ...

报错的原因是NumPy版本太高,需要降低版本。执行以下代码即可解决:

# 在虚拟环境中执行
conda activate jax_cuda12
pip uninstall numpy -y
pip install numpy==1.24.4  # 选择广泛兼容的1.x版本

4、安装成功!

import jax
print(jax.devices())  # 应输出类似 [gpu(id=0)]

import jax.numpy as jnp

在这里插入图片描述

那么,接下来,请享受你的加速计算吧。

import jax.numpy as jnp
def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(5.0)
print(selu(x))

在这里插入图片描述

联系我

如果你在Windows系统下安装jax过程中,有任何困难,请留言或者私信,我将定期回复。

考虑到您需要在Windows 10系统上,配合Python 3.7环境安装JAX库及其依赖库jaxlib,这里提供一个详细的安装指南,包括解决方案,以确保您可以顺利完成安装。首先,确保您已经安装了Python 3.7,并且已经升级到了最新版本的pip工具。接下来,根据《在Windows10 Python3.7中安装jaxjaxlib方法指南》中的指导,您需要下载适合您系统配置的whl文件进行安装。由于JAXjaxlib都是需要预先编译的库,下载对应版本的whl文件是避免编译错误的有效途径。以下是一步步安装的详细过程: 参考资源链接:[在Windows10 Python3.7中安装jaxjaxlib方法指南](https://wenku.csdn.net/doc/7g33rpgrkx?spm=1055.2569.3001.10343) 1. 首先打开命令提示符(cmd)或PowerShell,并切换到包含下载whl文件的目录。 2. 使用pip安装jax库的whl文件: ```bash pip install jax-0.2.9_and_jaxlib-0.1.61-cp37-win_amd64.whl ``` 3. 安装完成后,验证安装是否成功,可以在Python交互式环境输入以下代码: ```python import jax ``` 如果没有错误信息输出,那么说明JAX已经成功安装。 在安装过程中,可能会遇到一些常见问题。例如,如果系统提示找不到文件,那么请确保您下载的whl文件名完全正确,并且位于命令行指定的目录中。如果遇到版本不兼容的问题,检查下载的whl文件是否与您的系统架构(32位或64位)以及Python版本相匹配。此外,如果需要额外的依赖库或编译工具,如Visual Studio C++ Build Tools,那么请确保按照JAX官方文档进行安装配置。 安装完成后,JAX可以被用于进行高性能的机器学习深度学习计算。如果需要进一步学习如何使用JAX进行深度学习项目,建议查看JAX官方文档以及社区提供的教程项目案例。《在Windows10 Python3.7中安装jaxjaxlib方法指南》提供了基础安装的步骤注意事项,而对于更深入的学习应用,可以考虑查找更多的教程资源,例如在线课程、技术博客或者参加相关的技术交流活动。 参考资源链接:[在Windows10 Python3.7中安装jaxjaxlib方法指南](https://wenku.csdn.net/doc/7g33rpgrkx?spm=1055.2569.3001.10343)
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

高山莫衣

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值