Google最新开源机器学习框架,Github已超18万Stars!
前言
近年来深度学习框架市场依然呈现老牌两巨头TensorFlow和PyTorch的相爱相杀,不过由于TF这几年不断地更新而变得越来越臃肿难用,大家已经逐渐开始抛弃它(作者本人也已经转向了PyTorch,毕竟简单好用)。不过Google毕竟是大公司,肯定不能这么就随便让竞争对手的产品任意发展,因此也就有了今天本文所要介绍的最新框架JAX。
1、JAX是什么
JAX是Google于2018年开发并开源的科学计算库,其描述为Autograd和TensorFlow XLA的结合体。Autograd是pytorch计算反向传播时候的提供的一个组件,可以实现梯度的自动计算与更新。XLA(Accelerated Linear Algebra)是TensorFlow中内置的通过编译线性代数加速数据计算以提升机器学习模型速度的组件。可以这样理解,JAX就是一个在执行一些科学计算的时候,能够提供强大加速能力的工具。
官方文档:https://jax.readthedocs.io/en/latest/jax-101/index.html
项目地址:https://github.com/google/jax
2、 JAX能够用来做什么?
pip install --upgrade "jax[cpu]" # cpu版本
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # cuda版本
pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html # tpu版本
具体安装指南可以在官方Github上面看https://github.com/google/jax#installation。
有一些例子可以让我们简单的了解一下JAX可以做些什么(我电脑没有gpu所以下载的是cpu版本的jax)。
先构建一个三次方程然后使用numpy来进行计算:
1. import numpy as np
2. import time
3.
4.
5. def function(x):
6. # 构建一个三次方程
7. return x * x * x + x * x + x
8.
9.
10. x = np.random.randn(10000, 10000)
11. start_time = time.time()
12. for i in range(5):
13. loop_start = time.time()
14. ans = function(x)
15. loop_stop = time.time()
16. print(f'Loop {i+1} costs: {loop_stop - loop_start: .3f}')
17. stop_time = time.time()
18. print(f'5 loops cost total {stop_time - start_time: .3f}')
1. import numpy as np
2. import time
3. import jax
4. import jax.numpy as jnp
5.
6.
7. def function(x):
8. # 构建一个三次方程
9. return x * x * x + x * x + x
10.
11.
12. # 用Jax对三次方程进行计算并计时
13. x = np.random.randn(10000, 10000)
14. jax_function = jax.jit(function)
15. x = jnp.array(x)
16. start_time = time.time()
17. for i in range(5):
18. loop_start = time.time()
19. ans = jax_function(x)
20. loop_stop = time.time()
21. print(f'Loop {i+1} costs: {loop_stop - loop_start: .3f}')
22. stop_time = time.time()
23. print(f'5 loops cost total {stop_time - start_time: .3f}')
在JAX的官方文档也里面举了很多的代码例子来进行说明:https://github.com
/google/jax#installation。
当然JAX不仅仅只是一个超级加速过后的Numpy库,它同时提供了一个扩展系统用来转换函数。JAX目前提供用来转换函数的方法:
1) grad():自动计算微分
2) jit():即时编译加速代码(示例在上面已经给出)
3) vmap():自动向量化
4) pmap():并行编程
grad示例:
1) import time
2) import autograd.numpy as np
3) from autograd import grad
4)
5)
6) def tanh(x):
7) y = np.exp(-2.0 * x)
8) return (1.0 - y) / (1.0 + y)
9)
10)# 使用Autograd来计算微分
11)start_time = time.time()
12)print(grad(grad(grad(tanh)))(1.0))
13)end_time = time.time()
14)print(f'Cost {end_time - start_time:.8f} s')
jax实现:
1) import time
2) import autograd.numpy as np
3) from autograd import grad
4)
5)
6) def tanh(x):
7) y = np.exp(-2.0 * x)
8) return (1.0 - y) / (1.0 + y)
9)
10)
11)# 使用Autograd来计算微分
12)start_time = time.time()
13)print(jit(grad(jit(grad(jit(grad(jit(tanh)))))))(1.0))
14)end_time = time.time()
15)print(f'Cost {end_time - start_time:.8f} s')
这里进行对比的时候不知道为什么Autograd和Jax的计算时间会相差如此之大,这里推测可能是Jax操作的时候会先找GPU,因为运行环境没有GPU因此Jax将函数转移到了CPU内存上进行了操作,可能是这一次I/O操作导致了较高的计算时间。
由于运行环境没有GPU,因此对ymap和pmap就不做测试和展示了,具体的代码示例在项目Readme: https://github.com/google/jax#transformations
总结一下:
除了以上一些最基本的API之外,JAX还提供了其他的一些用于科学计算的API,可以阅读https://jax.readthedocs.io/en/latest/来查看这些库的用法。整体上JAX还是以计算任务为主,矩阵相乘、求导之类的计算在机器学习或深度学习中很常见,因此使用JAX去加速这些计算能够最大限度的发挥硬件性能,最大限度的做性能优化。当然,JAX还没有完整的生态,如果要在一个完整的大型深度学习项目中使用JAX,仍然需要依赖其他的第三方库,并且需要使用者有很好的数学背景,毕竟如果要使用JAX来设计一个模型,则需要将整个计算图中的节点均以函数的形式一一实现用JAX来实现。在深度学习方面,一些基于JAX强大计算性能的高级库如Flax、Haiku等,可能在未来生态逐渐发展起来之后,会有不错的表现。
3、 什么环境下可以使用JAX?
1)科研:
如果涉及的工作中需要使用大量的科学计算,可以使用JAX来替换Matlab或者Numpy来加速计算。如果构建大型的深度学习项目,可以使用JAX提供的pmap方法来分布式训练提高效率。如果是自己搭建的全新模型,可以搭配Pytorch或TensorFlow并使用JAX来构建和优化流程。
2)云端部署:
如果项目运行在TPU上,不管是训练还是推理,一定要使用JAX来优化加速,毕竟是谷歌自己开发的,对自己的设备支持是最好的。如果深度学习项目是在常规GPU上,目前还是PyTorch和TensorFlow比较好,不过依然可以使用JAX来对某些计算环节进行加速(工程量可能比较大),最好是等基于Jax的一些高级深度学习库发展出较完善生态的时候在使用这些高级API去搭建项目。