安装旧版本 jax 0.2.11 jaxlib 0.1.72 参考教程

要复现的工作需要用jax>=0.2.8版本,实际安装中遇到一些困难。现在将具体的版本记录在这里供大家参考。

linux 系统下安装 jax

还是建议大家在linux 下使用jax , 因为windows 还没有支持GPU

1. conda  创建一个 python=3.9的环境(这里用到了anaconda, 不了解的小伙伴可以找一下教程)

conda create -n py39jax python=3.9
conda activate py39jax

2. 安装 cudatoolkit 和 cudnn

这里我偷懒直接用了pytorch 官网的命令(直接装应该也是没问题的,但是我的源懒得换了)

conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge

其实我不太确定没有cudnn 8.0.4 能不能行 :(, 先装上了,大家试过可以给一些反馈

nvidia :: Anaconda.org (如果你想找的包找不到,也可以在这里直接搜)

conda install nvidia::cudnn

3. 安装jax jaxlib

pip install jax==0.2.11

pip install jaxlib==0.1.72+cuda111 -f https://storage.goog
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值