JAX vs Tensorflow

什么是JAX

JAX 是一套google提出的基于tensorflow的机器学习框架,是tensorflow的简化板和优化版。从命名上可以看出它对tensorflow的优化,A == autograd(https://pytorch.org/docs/stable/autograd.html) 即加入了自动微分,X == XLA (即:Accelerated LinearAlgebra)。编译过tf的同学知道,XLA是编译时加入的控制参数,编译后的tf包可以加速数据流图执行,提升内存使用效率,降低自定义操作依赖,减小移动应用内存占用,以及增强平台可移植性。

为什么有JAX

JAX的特性

JAX的四大特性,jit,grad,vmap,pmap

1)JIT Just-In-Time

使用jit装饰器(just-in-time),编译和自动加速

@jit
def update(params, x, y):
  grads = grad(loss)(params, x, y)
  return [(w - step_size * dw, b- step_size) * db] for (w, b), (dw, db) in zip(params, grads)

2) pmap

pmap可以使数据在与多GPUS上并行计算。

3)vmap(vector

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值