1.3 万 Star!迅猛发展的 JAX 对比 TensorFlow、PyTorch

JAX是谷歌推出的新一代机器学习框架,结合Autograd和XLA,提供自动微分、GPU/TPU支持等功能。与TensorFlow和PyTorch相比,JAX更专注于numpy的加速和便捷性,正在逐步获得研究者的关注。安装JAX后,可以通过grad()、jit()、pmap()和vmap()等特性进行高效计算。尽管TensorFlow和PyTorch各有优势,但JAX的动态性和易用性使其在科研领域崭露头角。
摘要由CSDN通过智能技术生成

JAX 是机器学习 (ML) 领域的新生力量,它有望使 ML 编程更加直观、结构化和简洁。

在机器学习领域,大家可能对 TensorFlow 和 PyTorch 已经耳熟能详,但除了这两个框架,一些新生力量也不容小觑,它就是谷歌推出的 JAX。很对研究者对其寄予厚望,希望它可以取代 TensorFlow 等众多机器学习框架。

JAX 最初由谷歌大脑团队的 Matt Johnson、Roy Frostig、Dougal Maclaurin 和 Chris Leary 等人发起。

目前,JAX 在 GitHub 上已累积 13.7K 星。

项目地址:https://github.com/google/jax

迅速发展的 JAX

JAX 的前身是 Autograd,其借助 Autograd 的更新版本,并且结合了 XLA,可对 Python 程序与 NumPy 运算执行自动微分,支持循环、分支、递归、闭包函数求导,也可以求三阶导数;依赖于 XLA,JAX 可以在 GPU 和 TPU 上编译和运行 NumPy 程序;通过 grad,可以支持自动模式反向传播和正向传播,且二者可以任意组合成任何顺序。

开发 JAX 的出发点是什么?说到这,就不得不提 NumPy。NumPy 是 Python 中的一个基础数值运算库,被广泛使用。但是 numpy 不支持 GPU 或其他硬件加速器,也没有对反向传播的内置支持,此外,Python 本身的速度限制阻碍了 NumPy 使用,所以少有研究者在生产环境下直接用 numpy 训练或部署深度学习模型。

在此情况下,出现了众多的深度学习框架,如 PyTorch、TensorFlow 等。但是 numpy 具有灵活、调试方便、API 稳定等独特的优势。而 JAX 的主要出发点就是将 numpy 的以上优势与硬件加速结合。

目前,基于 JAX 已有很多优秀的开源项目,如谷歌的神经网络库团队开发了 Haiku,这是一个面向 Jax 的深度学习代码库,通过 Haiku,用户可以在 Jax 上进行面向对象开发;又比如 RLax,这是一个基于 Jax 的强化学习库,用户使用 RLax 就能进行 Q-learning 模型的搭建和训练;此外还包括基于 JAX 的深度学习库 JAXnet,该库一行代码就能定义计算图、可进行 GPU 加速。可以说,在过

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值