技术分享 | 能微分会加速的 NumPy —— JAX

目录

# 使用介绍 #

# 自动微分

# vmap 和 pmap

# JIT 编译

# 内部实现 #

# Trace 变换

# Jaxpr:JAX 中间表达式

# 总结 #

参考


JAX [1] 是 Google 推出的可以对 NumPy 和 Python 代码进行自动微分并跑到 GPU/TPU(Google 自研张量加速器)加速的机器学习库。Numpy [2] 是 Python 著名的数组运算库,官方版本只支持 CPU 运行(后面 Nvidia 推出的 CuPy 支持 GPU 加速,这里按住不表)。JAX 前身是 AutoGrad [3],2015 年哈佛大学来自物理系和 SEAS(工程与应用科学学院)的师生发表论文推出的支持 NumPy 程序自动求导的机器学习库。AutoGrad 提供和 NumPy 库一致的编程接口,用户导入 AutoGrad 就可以让原来写的 NumPy 程序拿来求导。JAX 在 2018 年将 XLA [4](Tensorflow 线性代数领域编译器)引入进来,使得 Python 程序可以通过 XLA 编译跑到 GPU/TPU 加速器上。简单地理解 JAX = NumPy + AutoGrad + XLA。可以说,XLA 加持下的 JAX,才真正具备了实施深度学习训练的基础和能力。

JAX Github: https://github.com/google/jax

JAX API Docs: https://jax.readthedocs.io/en/latest/


# 使用介绍 #

# 自动微分

JAX 提供兼容 NumPy 风格的接口,照顾用户原 NumPy 编程习惯。JAX 面向 Python 用户提供自动微分接口,包括生成梯度函数、求导等。

例 1:使用 jax.grad() 求导

from jax import grad

def f(x):
  return x * x * x

D_f = grad(f) # 3x^2
D2_f = grad(D_f) # 6x
D3_f = grad(D2_f) # 6

f(1.0) # 1.0
D_f(1.0) # 3.0
D2_f(1.0) # 6.0
D3_f(0.0) # 6.0 (always)

jax.grad:只接受输出标量的原始函数 f,生成对应的梯度函数 ▼f▼f 接受和原始函数一样的入参 x,输出为参数梯度 dx▼f 亦可被 grad(),相当于对原始函数计算高阶梯度,但需满足一样的要求:输出为标量。如果被求导的函数计算结果不止一个数值,不能直接传给 grad()。需要先 reduce 成一个标量。

例 2:对数组函数求导

from jax import numpy as np
from jax import grad
import matplotlib.pyplot as plt

def f(x):
    return x * x * x

D_f = grad(lambda x: np.sum(f(x)))
D2_f = grad(lambda x: np.sum(D_f(x)))
D3_f = grad(lambda x: np.sum(D2_f(x)))

x = np.linspace(-1, 1, 200)
plt.plot(x, f(x), x, D_f(x), x, D2_f(x), x, D3_f(x))
plt.show()

和例 1 相比主要区别在于例 2 分别对函数(fD_fD2_f)结果进行求和(sum)再求导。函数 f 和它的一阶、二阶、三阶导函数曲线如下图所示。

图片

JAX 支持不同模式自动微分。grad() 默认采取反向模式自动微分。另外显式指定模式的微分接口有 jax.vjp 和 jax.jvp

  • jax.vjp:反向模式自动微分。根据原始函数 f、输入 x 计算函数结果  y 并生成梯度函数 

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值