TensorFlow官方简化版!谷歌开源机器学习库JAX

AI前线导读:什么?TensorFlow 有了替代品?什么?竟然还是谷歌自己做出来的?先别慌,从各种意义上来说,这个所谓的“替代品”其实是 TensorFlow 的一个简化库,名为 JAX,结合 Autograd 和 XLA,可以支持部分 TensorFlow 的功能,但是比 TensorFlow 更加简洁易用。虽然还不至于替代 TensorFlow,但已经有 Reddit 网友对 JAX 寄予厚望,并表示“早就期待能有一个可以直接调用 Numpy API 接口的库了!”,“希望它可以取代 TensorFlow!”。

更多干货内容请关注微信公众号“AI前线”(ID:ai-front)

JAX结合了Autograd和XLA,是专为高性能机器学习研究打造的产品。

\"image\"

有了新版本的Autograd,JAX能够自动对Python和NumPy的自带函数求导,支持循环、分支、递归、闭包函数求导,而且可以求三阶导数。它支持自动模式反向求导(也就是反向传播)和正向求导,且二者可以任意组合成任何顺序。

JAX的创新之处在于,它基于XLA在GPU和TPU上编译和运行NumPy程序。默认情况下,编译是在底层进行的,库调用能够及时编译和执行。但是JAX还允许使用单一函数API jit将自己的Python函数及时编译成经过XLA优化的内核。编译和自动求导可以任意组合,因此可以在不脱离Python环境的情况下实现复杂算法并获得最优性能。

JAX最初由Matt Johnson、Roy Frostig、Dougal Maclaurin和Chris Leary发起,他们均任职于谷歌大脑团队。在GitHub的说明文档中,作者明确表示:JAX目前还只是一个研究项目,不是谷歌的官方产品,因此可能会有一些bug。从作者的GitHub简介来看,这应该是谷歌大脑正在尝试的新项目,在同一个GitHub目录下的开源项目还包括8月份在业内引起热议的强化学习框架Dopamine。

以下是JAX的简单使用示例。

\"image\"

GitHub项目传送门:https://github.com/google/JAX

有关具体的安装和简单的入门指导大家可以在GitHub中自行查看,在此不做过多赘述。

JAX库的实现原理

机器学习中的编程是关于函数的表达和转换。转换包括自动微分、加速器编译和自动批处理。像Python这样的高级语言非常适合表达函数,但是通常使用者只能应用它们。我们无法访问它们的内部结构,因此无法执行转换。

JAX可以用于专门化高级Python+NumPy函数,并将其转换为可转换的表示形式,然后再提升为Python函数。

\"image\"

JAX通过跟踪专门处理Python函数。跟踪一个函数意味着:监视应用于其输入,以产生其输出的所有基本操作,并在有向无环图(DAG)中记录这些操作及其之间的数据流。为了执行跟踪,JAX包装了基本的操作,就像基本的数字内核一样,这样一来,当调用它们时,它们就会将自己添加到执行的操作列表以及输入和输出中。为了跟踪这些原语之间的数据流,跟踪的值被包装在Tracer类的实例中。

当Python函数被提供给grad或jit时,它被包装起来以便跟踪并返回。当调用包装的函数时,我们将提供的具体参数抽象到AbstractValue类的实例中,将它们框起来用于跟踪跟踪器类的实例,并对它们调用函数。

抽象参数表示一组可能的值,而不是特定的值:例如,jit将ndarray参数抽象为抽象值,这些值表示具有相同形状和数据类型的所有ndarray。相反,grad抽象ndarray参数来表示底层值的无穷小邻域。通过在这些抽象值上跟踪Python函数,我们确保它足够专门化,以便转换是可处理的,并且它仍然足够通用,以便转换后的结果是有用的,并且可能是可重用的。然后将这些转换后的函数提升回Python可调用函数,这样就可以根据需要跟踪并再次转换它们。

JAX跟踪的基本函数大多与XLA HLO 1:1对应,并在lax.py中定义。这种1:1的对应关系使得到XLA的大多数转换基本上都很简单,并且确保我们只有一小组原语来覆盖其他转换,比如自动微分。 jax.numpy层是用纯Python编写的,它只是用LAX函数(以及我们已经编写的其他numpy函数)表示numpy函数。这使得jax.numpy易于延展。

当你使用jax.numpy时,底层LAX原语是在后台进行jit编译的,允许你在加速器上执行每个原语操作的同时编写不受限制的Python+ numpy代码。

但是JAX可以做更多的事情:你可以在越来越大的函数上使用jit来进行端到端编译和优化,而不仅仅是编译和调度到一组固定的单个原语。例如,可以编译整个网络,或者编译整个梯度计算和优化器更新步骤,而不仅仅是编译和调度卷积运算。

折衷之处是,jit函数必须满足一些额外的专门化需求:因为我们希望编译专门针对形状和数据类型的跟踪,但不是专门针对具体值的跟踪,所以jit装饰器下的Python代码必须适用于抽象值。如果我们尝试在一个抽象的x上求x \u0026gt;0的值,结果是一个抽象的值,表示集合{True, False},所以Python分支就像if x \u0026gt; 0会引起报错。

有关使用jit的更多要求,请参见:https://github.com/google/jax#whats-supported

好消息是,jit是可选的:JAX库在后台对单个操作和函数使用jit,允许编写不受限制的Python+Numpy,同时仍然使用硬件加速器。但是,当你希望最大化性能时,通常可以在自己的代码中使用jit编译和端到端优化更大的函数。

后续计划

目前项目小组还将对以下几项做更多尝试和更新:

  1. 完善说明文档

  2. 支持Cloud TPU

  3. 支持多GPU和多TPU

  4. 支持完整的NumPy功能和部分SciPy功能

  5. 全面支持vmap

  6. 加速

    1. 降低XLA函数调度开销
    2. 线性代数例程(CPU上的MKL和GPU上的MAGMA)
  7. 高效自动微分原语condwhile

有关JAX库的介绍大致如此,如果你在尝试了JAX之后有一些较好的使用心得,欢迎随时向我们投稿,AI前线十分愿意将你的经验传播给更多开发者。

再次附上GitHub链接:https://github.com/google/jax

相关资源:

JAX论文链接:https://www.sysml.cc/doc/146.pdf

会议推荐


12月20-21, AICon将于北京开幕,在这里可以学习来自Google、微软、BAT、360、京东、美团等40+AI落地案例,与国内外一线技术大咖面对面交流。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值