Q1: 为什么使用jax,python当前火热的框架比如pytorch 和TensorFlow 已经很好用了?
A1: 这个jax的优势,可以去知乎上搜jax 会出现很多帖子
也可以去官网上面看官方解释
https://jax.readthedocs.io/en/latest/index.html
Q2: 那么如何安装jax呢?
A2: GPU version
pip install --upgrade jax==0.2.3 jaxlib==0.1.69+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
CPU version
pip install jax
pip install jaxlib
Attention1 : 如果安装CPU version就不要期待教高的速度,有条件的还是装GPU version的吧!
Q3 : GPU 版本的如何指定显卡号?
A3 : 一开始我以为所有的python 框架尤其是新出来的框架都和pytorch一样框架有自身独特的指定显卡的卡号,但是jax 还真没有找到。而是使用对所有python代码都有用的指定方式:
import os
os.environ['CUDA_VISIBLE_DEVICES']='1'