random.PRNGKey
from jax import random
key = random.PRNGKey(1)
print(key)
PRNGKey会生成一个(2,)shape array来作为seed的值
output: [0 1]
在未来需要生成随机数的时候,可以直接使用key值来作为seed,方便操作。
x = random.normal
from jax import random
key = random.PRNGKey(1)
print(key)
PRNGKey会生成一个(2,)shape array来作为seed的值
output: [0 1]
在未来需要生成随机数的时候,可以直接使用key值来作为seed,方便操作。
x = random.normal