设置随机数种子
random.PRNGKey() 是jax中用于生成伪随机数生成器(PRNG)密钥的函数.
基本用法:
import jax
from jax import random
# 生成一个随机数密钥
key = random.PRNGKey(42)
42是随机数生成器的种子,key是一个包含两个整数的jax数组,表示随机数生成器的状态. 具体来说,这两个整数的含义如下:
第一个整数:表示伪随机数生成器的主要部分,它用于确定随机数生成的当前状态。
第二个整数:作为辅助部分,用于帮助确定生成的随机数序列,并确保生成的随机数在并行计算中是独立的。
使用生成的随机数密钥
生成随机数密钥后,可以将其用于各种随机数生成函数.每次使用密钥后,最好通过拆分密钥来生成新的子密钥,避免使用同一个密钥生成多个随机数.例如:
import jax
import jax.numpy as jnp
from jax import random
# 生成一个随机数密钥
key = random.PRNGKey(42)
# 拆分密钥
key1, key2 = random.split(key)
# 使用不同的子密钥生成随机数
random_numbers1 = random.uniform(key1, shape=(3,))
random_numbers2 = random.uniform(key2, shape=(3,))
print(random_numbers1)
print(random_numbers2)
拆分出来的key1, key2是两个新的子密钥,不同于key.