Google最新开源机器学习框架,Github已超18万Stars!

Google最新开源机器学习框架,Github已超18万Stars!


前言

    近年来深度学习框架市场依然呈现老牌两巨头TensorFlow和PyTorch的相爱相杀,不过由于TF这几年不断地更新而变得越来越臃肿难用,大家已经逐渐开始抛弃它(作者本人也已经转向了PyTorch,毕竟简单好用)。不过Google毕竟是大公司,肯定不能这么就随便让竞争对手的产品任意发展,因此也就有了今天本文所要介绍的最新框架JAX。

1、JAX是什么

    JAX是Google于2018年开发并开源的科学计算库,其描述为Autograd和TensorFlow XLA的结合体。Autograd是pytorch计算反向传播时候的提供的一个组件,可以实现梯度的自动计算与更新。XLA(Accelerated Linear Algebra)是TensorFlow中内置的通过编译线性代数加速数据计算以提升机器学习模型速度的组件。可以这样理解,JAX就是一个在执行一些科学计算的时候,能够提供强大加速能力的工具。
    官方文档:https://jax.readthedocs.io/en/latest/jax-101/index.html
    项目地址:https://github.com/google/jax


2、 JAX能够用来做什么?

图1 开发团队对JAX项目上的介绍
    从开发团队对目前JAX的一些功能和特性的描述来看,JAX从设计方向上并不是一个深度学习库或者框架,相比较于TensorFlow和PyTorch这两个拥有完善生态的深度学习框架,JAX所能够提供的更多是一个组合计算库,同时包含了部分对深度学习的支持。 目前JAX主要对标老牌计算库Numpy。Numpy很多计算不支持并行计算,无法调动多核计算性能。同时Numpy的计算仅限于CPU,对于标量计算来说,CPU和GPU的性能差异不大,但是在矩阵向量计算上,GPU的并行计算能力则可以完全打败CPU。在Pytorch中,我们可以规定让Tensor对象在GPU上计算并进行加速,同理JAX提供了API可以让我们在GPU/TPU上进行计算加速。 JAX提供标准的Python Package库,拥有CPU版本、GPU(CUDA)版本还有TPU版本,目前Jax由于要使用XLA而必须安装jaxlib库,jaxlib库仅支持Linux(Ubuntu16.04 +)和MacOs(10.12+)平台上安装,如果要在Windows上安装,则需要使用WSL虚拟机。
pip install --upgrade "jax[cpu]" # cpu版本
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # cuda版本
pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html # tpu版本

具体安装指南可以在官方Github上面看https://github.com/google/jax#installation

有一些例子可以让我们简单的了解一下JAX可以做些什么(我电脑没有gpu所以下载的是cpu版本的jax)。
先构建一个三次方程然后使用numpy来进行计算:

1.	import numpy as np  
2.	import time  
3.	  
4.	  
5.	def function(x):  
6.	    # 构建一个三次方程  
7.	    return x * x * x + x * x + x  
8.	  
9.	  
10.	x = np.random.randn(10000, 10000)  
11.	start_time = time.time()  
12.	for i in range(5):  
13.	    loop_start = time.time()  
14.	    ans = function(x)  
15.	    loop_stop = time.time()  
16.	    print(f'Loop {i+1} costs: {loop_stop - loop_start: .3f}')  
17.	stop_time = time.time()  
18.	print(f'5 loops cost total {stop_time - start_time: .3f}') 

Numpy计算三次方程的结果

图2 Numpy计算三次方程的结果
使用Jax运行:
1.	import numpy as np  
2.	import time  
3.	import jax  
4.	import jax.numpy as jnp  
5.	  
6.	  
7.	def function(x):  
8.	    # 构建一个三次方程  
9.	    return x * x * x + x * x + x  
10.	  
11.	  
12.	# 用Jax对三次方程进行计算并计时  
13.	x = np.random.randn(10000, 10000)  
14.	jax_function = jax.jit(function)  
15.	x = jnp.array(x)  
16.	start_time = time.time()  
17.	for i in range(5):  
18.	    loop_start = time.time()  
19.	    ans = jax_function(x)  
20.	    loop_stop = time.time()  
21.	    print(f'Loop {i+1} costs: {loop_stop - loop_start: .3f}')  
22.	stop_time = time.time()  
23.	print(f'5 loops cost total {stop_time - start_time: .3f}')

在这里插入图片描述

图3 使用JAX计算三次方程的结果
    两份代码均在CPU上运行,而JAX已经实现了3倍以上的加速效果,如果能够在GPU上运行,将会更加大幅度的提升计算速度,而这仅仅是一个简单的方程计算。Numpy能够实现的各种复杂运算(比如矩阵相乘,点积计算)Jax都兼容,并且有强大加速能力加持。由于Jax提供了jax.numpy库,几乎拥有和Numpy相同的API,因此使用Jax来替代Numpy进行这种科学计算任务就成为了使用Jax的一大理由之一。

在JAX的官方文档也里面举了很多的代码例子来进行说明:https://github.com
/google/jax#installation

    当然JAX不仅仅只是一个超级加速过后的Numpy库,它同时提供了一个扩展系统用来转换函数。JAX目前提供用来转换函数的方法:
  1) grad():自动计算微分
  2) jit():即时编译加速代码(示例在上面已经给出)
  3) vmap():自动向量化
  4) pmap():并行编程
grad示例:

1import time  
2import autograd.numpy as np  
3from autograd import grad  
456def tanh(x):  
7)	    y = np.exp(-2.0 * x)  
8return (1.0 - y) / (1.0 + y)  
910# 使用Autograd来计算微分  
11)start_time = time.time()  
12print(grad(grad(grad(tanh)))(1.0))  
13)end_time = time.time()  
14print(f'Cost {end_time - start_time:.8f} s')  

在这里插入图片描述

图4 使用Autograd计算微分的结果

jax实现:

1import time  
2import autograd.numpy as np  
3from autograd import grad  
456def tanh(x):  
7)	    y = np.exp(-2.0 * x)  
8return (1.0 - y) / (1.0 + y)  
91011# 使用Autograd来计算微分  
12)start_time = time.time()  
13print(jit(grad(jit(grad(jit(grad(jit(tanh)))))))(1.0)) 
14)end_time = time.time()  
15print(f'Cost {end_time - start_time:.8f} s') 

在这里插入图片描述

图5 使用Jax计算微分的结果

    这里进行对比的时候不知道为什么Autograd和Jax的计算时间会相差如此之大,这里推测可能是Jax操作的时候会先找GPU,因为运行环境没有GPU因此Jax将函数转移到了CPU内存上进行了操作,可能是这一次I/O操作导致了较高的计算时间。
由于运行环境没有GPU,因此对ymap和pmap就不做测试和展示了,具体的代码示例在项目Readme: https://github.com/google/jax#transformations
    总结一下:
    除了以上一些最基本的API之外,JAX还提供了其他的一些用于科学计算的API,可以阅读https://jax.readthedocs.io/en/latest/来查看这些库的用法。整体上JAX还是以计算任务为主,矩阵相乘、求导之类的计算在机器学习或深度学习中很常见,因此使用JAX去加速这些计算能够最大限度的发挥硬件性能,最大限度的做性能优化。当然,JAX还没有完整的生态,如果要在一个完整的大型深度学习项目中使用JAX,仍然需要依赖其他的第三方库,并且需要使用者有很好的数学背景,毕竟如果要使用JAX来设计一个模型,则需要将整个计算图中的节点均以函数的形式一一实现用JAX来实现。在深度学习方面,一些基于JAX强大计算性能的高级库如Flax、Haiku等,可能在未来生态逐渐发展起来之后,会有不错的表现。

3、 什么环境下可以使用JAX?

1)科研:
    如果涉及的工作中需要使用大量的科学计算,可以使用JAX来替换Matlab或者Numpy来加速计算。如果构建大型的深度学习项目,可以使用JAX提供的pmap方法来分布式训练提高效率。如果是自己搭建的全新模型,可以搭配Pytorch或TensorFlow并使用JAX来构建和优化流程。
2)云端部署:
    如果项目运行在TPU上,不管是训练还是推理,一定要使用JAX来优化加速,毕竟是谷歌自己开发的,对自己的设备支持是最好的。如果深度学习项目是在常规GPU上,目前还是PyTorch和TensorFlow比较好,不过依然可以使用JAX来对某些计算环节进行加速(工程量可能比较大),最好是等基于Jax的一些高级深度学习库发展出较完善生态的时候在使用这些高级API去搭建项目。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值