什么是JAX
JAX 是一套google提出的基于tensorflow的机器学习框架,是tensorflow的简化板和优化版。从命名上可以看出它对tensorflow的优化,A == autograd(https://pytorch.org/docs/stable/autograd.html) 即加入了自动微分,X == XLA (即:Accelerated LinearAlgebra)。编译过tf的同学知道,XLA是编译时加入的控制参数,编译后的tf包可以加速数据流图执行,提升内存使用效率,降低自定义操作依赖,减小移动应用内存占用,以及增强平台可移植性。
为什么有JAX
JAX的特性
JAX的四大特性,jit,grad,vmap,pmap
1)JIT Just-In-Time
使用jit装饰器(just-in-time),编译和自动加速
@jit
def update(params, x, y):
grads = grad(loss)(params, x, y)
return [(w - step_size * dw, b- step_size) * db] for (w, b), (dw, db) in zip(params, grads)
2) pmap
pmap可以使数据在与多GPUS上并行计算。
3)vmap(vector