jax循环语句

import jax
import jax.numpy as jnp

### 1. jax.lax.while_loop
# jax.lax.while_loop(cond_fun, body_fun, init_val)
init_val = 0
cond_fun = lambda x: x<10
body_fun = lambda x: x+1
result = jax.lax.while_loop(cond_fun, body_fun, init_val)
print(result)

### 2. jax.lax.fori_loop
# jax.lax.fori_loop(lower, upper, body_fun, init_val, *, unroll=None)
init_val = 0
start = 0
stop = 10
body_fun = lambda i,x: x+i
result = jax.lax.fori_loop(start, stop, body_fun, init_val)
print(result)

### 3.jax.lax.scan
# jax.lax.scan(f, init, xs, length=None, reverse=False, unroll=1)
# 使用scan的时候,carry的变量也是需要显式地定义在函数中,并且是return的第一个变量

def f(carry, x):
    x = carry + x
    return x, x

xs = jnp.array([0, 1, 2, 3,])
result = jax.lax.scan(f, 0, xs)
print(result)

#haiku.scan(f, init, xs, length=None, reverse=False, unroll=1)[source]
#Equivalent to jax.lax.scan() but with Haiku state passed in/out.

参考:

https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html

https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html

https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#

https://dm-haiku.readthedocs.io/en/latest/api.html?highlight=haiku.scan#haiku.scan

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值