要复现的工作需要用jax>=0.2.8版本,实际安装中遇到一些困难。现在将具体的版本记录在这里供大家参考。
linux 系统下安装 jax
还是建议大家在linux 下使用jax , 因为windows 还没有支持GPU
1. conda 创建一个 python=3.9的环境(这里用到了anaconda, 不了解的小伙伴可以找一下教程)
conda create -n py39jax python=3.9
conda activate py39jax
2. 安装 cudatoolkit 和 cudnn
这里我偷懒直接用了pytorch 官网的命令(直接装应该也是没问题的,但是我的源懒得换了)
conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge
其实我不太确定没有cudnn 8.0.4 能不能行 :(, 先装上了,大家试过可以给一些反馈
nvidia :: Anaconda.org (如果你想找的包找不到,也可以在这里直接搜)
conda install nvidia::cudnn
3. 安装jax jaxlib
pip install jax==0.2.11
pip install jaxlib==0.1.72+cuda111 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
在这里的另一个tips 是一定要学会在官网找信息Installing JAX — JAX documentation
4. 其他的配套工具
pip install chex==0.1.2
pip install dm-haiku==0.0.5
pip install optax==0.0.3
5. 也许还会遇到一些numpy scipy 版本问题,大家按照提示修改就可以了。
windows 安装 jax
步骤和上面差不多,jaxlib package 地址要换一下。
https://whls.blob.core.windows.net/unstable/index.html
最后附上conda list (这个是windows 版本的)
# Name Version Build Channel
absl-py 2.1.0 pypi_0 pypi
ca-certificates 2024.3.11 haa95532_0 https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
chex 0.1.2 pypi_0 pypi
colorama 0.4.6 pypi_0 pypi
cuda-nvcc 12.4.131 0 nvidia
cudatoolkit 11.1.1 heb2d755_10 https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge
cudnn 8.2.1 cuda11.3_0 https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
dm-haiku 0.0.5 pypi_0 pypi
dm-tree 0.1.8 pypi_0 pypi
flatbuffers 2.0.7 pypi_0 pypi
importlib-metadata 7.1.0 pypi_0 pypi
jax 0.2.11 pypi_0 pypi
jaxlib 0.1.72 pypi_0 pypi
jmp 0.0.4 pypi_0 pypi
joblib 1.4.2 pypi_0 pypi
ml-dtypes 0.4.0 pypi_0 pypi
numpy 1.23.5 pypi_0 pypi
openssl 1.1.1w h2bbff1b_0 https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
opt-einsum 3.3.0 pypi_0 pypi
optax 0.0.3 pypi_0 pypi
pip 23.3.1 py39haa95532_0 https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
python 3.9.0 h6244533_2 https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
scikit-learn 1.4.2 pypi_0 pypi
scipy 1.12.0 pypi_0 pypi
setuptools 69.5.1 py39haa95532_0 https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
sqlite 3.45.3 h2bbff1b_0 https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
tabulate 0.9.0 pypi_0 pypi
threadpoolctl 3.5.0 pypi_0 pypi
toolz 0.12.1 pypi_0 pypi
tqdm 4.66.4 pypi_0 pypi
typing-extensions 4.11.0 pypi_0 pypi
tzdata 2024a h04d1e81_0 https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
vc 14.2 h21ff451_1 https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
vs2015_runtime 14.27.29016 h5e58377_2 https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
wheel 0.43.0 py39haa95532_0 https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
zipp 3.18.1 pypi_0 pypi