近期复现别人代码时需要使用flax、jax、tensorflow,在安装过程中遇到了一些bug,记录在此。
Tips:
- 使用import jax.numpy as jnp; a = jnp.zeros([2,3]); print(a.device()) 可以确认Jax使用的是GPU。
- 使用
tf.config.list_physical_devices('GPU')
可以确认 TensorFlow 使用的是 GPU。 - 默认情况下,TensorFlow 会映射进程可见的所有 GPU(取决于 CUDA_VISIBLE_DEVICES)的几乎全部内存。这是为了减少内存碎片,更有效地利用设备上相对宝贵的 GPU 内存资源。可以将环境变量
TF_FORCE_GPU_ALLOW_GROWTH
设置为true来
开启内存增长。(参考https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth) - 当第一个 JAX 操作运行时,JAX 将预分配 90% 当前可用的 GPU 内存。可以设置XLA_PYTHON_CLIENT_PREALLOCATE=false来禁用预