要复现的工作需要用jax>=0.2.8版本,实际安装中遇到一些困难。现在将具体的版本记录在这里供大家参考。
linux 系统下安装 jax
还是建议大家在linux 下使用jax , 因为windows 还没有支持GPU
1. conda 创建一个 python=3.9的环境(这里用到了anaconda, 不了解的小伙伴可以找一下教程)
conda create -n py39jax python=3.9
conda activate py39jax
2. 安装 cudatoolkit 和 cudnn
这里我偷懒直接用了pytorch 官网的命令(直接装应该也是没问题的,但是我的源懒得换了)
conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge
其实我不太确定没有cudnn 8.0.4 能不能行 :(, 先装上了,大家试过可以给一些反馈
nvidia :: Anaconda.org (如果你想找的包找不到,也可以在这里直接搜)
conda install nvidia::cudnn
3. 安装jax jaxlib
pip install jax==0.2.11
pip install jaxlib==0.1.72+cuda111 -f https://storage.goog