Numpy和JAX中的随机数

前言

本文主要翻译自JAX在github上的一篇文档(Authors: Matteo Hessel & Rosalia Schneider),同时增加了部分个人理解。
原文链接如下:
https://github.com/google/jax/blob/main/docs/jax-101/05-random-numbers.md

关于伪随机数的生成,pseudo random number generation:
伪随机数并非真正的随机数,伪随机数是根据一定的算法,依据初始值(种子,key)生成的数值,当算法不变时,生成的结果不变。

在各个方面,JAX力求和Numpy保持一致,而在伪随机数方面是一个例外。下面将具体介绍一下JAX和Numpy之间关于伪随机数的区别。

Numpy中的伪随机数

numpy中的伪随机数通过np.random生成,伪随机数的状态是全局统一的(原文:In NumPy, pseudo random number generation is based on a global state)。
怎么理解这句话呢?我理解在Numpy中,所有的随机数状态均是基于一种算法生成的,即所有的随机数的状态均在一个序列当中。
这里再介绍一下状态,即state。在Numpy当中,有一个方法是np.random.get_state(),在官方文档中,解释为:Return a tuple representing the internal state of the generator。即返回一个代表生成器内部状态的元组。我理解这个状态和种子的概念是类似的,在同一状态下,得到的随机数是相同的。
我们看看这个状态具体是个什么东东:

def print_truncated_random_state():
  """To avoid spamming the outputs, print only part of the state."""
  full_random_state = np.random.get_state()
  print(str(full_random_state)[:460], '...')
  
print_truncated_random_state()

('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660,
       3904844661,  676747479, 2085143622, 1056793272, 3812477442,
       2168787041,  275552121, 2696932952, 3432054210, 1657102335,
       3518946594,  962584079, 1051271004, 3806145045, 1414436097,
       2032348584, 1661738718, 1116708477, 2562755208, 3176189976,
        696824676, 2399811678, 3992505346,  569184356, 2626558620,
        136797809, 4273176064,  296167901, 343 ...

插曲:可以看到,输出结果中有"MT19937" 个东东,这个是个什么东西了?
查了一下:MT19937表示一个伪随机数生成算法。
梅森旋转算法(Mersenne twister)是一个伪随机数发生算法。由松本真和西村拓士[1]在1997年开发,基于有限二进制字段上的矩阵线性递归 F 2 F_2 F2。可以快速产生高质量的伪随机数,修正了古典随机数发生算法的很多缺陷。
Mersenne Twister这个名字来自周期长度取自梅森质数的这样一个事实。这个算法通常使用两个相近的变体,不同之处在于使用了不同的梅森素数。一个更新的和更常用的是MT19937, 32位字长。还有一个变种是64位版的MT19937-64。对于一个k位的长度,Mersenne Twister会在 [ 0 , 2 k − 1 ] [0,2^k-1] [0,2k1]的区间之间生成离散型均匀分布的随机数。

在每次调用Numpy时,state都会更新:

np.random.seed(0)
print_truncated_random_state()

('MT19937', array([         0,          1, 1812433255, 1900727105, 1208447044,
       2481403966, 4042607538,  337614300, 3232553940, 1018809052,
       3202401494, 1775180719, 3192392114,  594215549,  184016991,
        829906058,  610491522, 3879932251, 3139825610,  297902587,
       4075895579, 2943625357, 3530655617, 1423771745, 2135928312,
       2891506774, 1066338622,  135451537,  933040465, 2759011858,
       2273819758, 3545703099, 2516396728, 127 ...
_ = np.random.uniform()
print_truncated_random_state()

('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660,
       3904844661,  676747479, 2085143622, 1056793272, 3812477442,
       2168787041,  275552121, 2696932952, 3432054210, 1657102335,
       3518946594,  962584079, 1051271004, 3806145045, 1414436097,
       2032348584, 1661738718, 1116708477, 2562755208, 3176189976,
        696824676, 2399811678, 3992505346,  569184356, 2626558620,
        136797809, 4273176064,  296167901, 343 ...

同时我们也可以把状态保存下来,后面可以读取这个状态来返回相同的值。

state=np.random.get_state()  #获取并保存state
np.random.uniform()
0.6027633760716439
np.random.set_state(state) #读取已保存的状态
np.random.uniform()
0.6027633760716439  ---两个随机数是一样的

在Numpy里面,不仅一次可以获取一个随机数,还可以获取一个随机的向量或是张量。

np.random.seed(0)
print(np.random.uniform(size=3))
[0.5488135  0.71518937 0.60276338]

Numpy中一个比较有意思的东西是,它提供了一个顺序等价保证,怎么理解呢?就是同样是3个数字,分3次取3个数和一次取一个包含3个元素的向量得到的结果是一样的。
如下:

np.random.seed(0)
print("individually:", np.stack([np.random.uniform() for _ in range(3)]))
individually: [0.5488135  0.71518937 0.60276338]
np.random.seed(0)
print("all at once: ", np.random.uniform(size=3))
all at once:  [0.5488135  0.71518937 0.60276338]

是不是很神奇?这个如果在搞明白为什么会这样,估计是要研究下Numpy的random.uniform的实现代码了,我估计np.random.uniform(size=3)也是循环生成的。

JAX中的伪随机数

JAX中的伪随机数与Numpy中有很大不同,Numpy中的随机数设计满足不了JAX的需求。JAX要求具备以下特点:

  1. 可重复,reproducible ——个人理解意思为可复现,即重复操作时结果是一样的(就是种子的意思)
  2. 可并行,parallelizable
  3. 向量化,vectorisable
    下面我们来具体说明。
    首先我们看看一个全局状随机数的含义。上代码:
import numpy as np
np.random.seed(0)
def bar(): return np.random.uniform()
def baz(): return np.random.uniform()
def foo(): return bar() + 2 * baz()
print(foo())
------------------
1.9791922366721637

上面的代码中,方法foo表示两个服从均匀分布的标量之和。
在假设方法bar和方法baz是按一定顺序执行的情况下,计算结果才能满足JAX的三点要求的第一条,即reproducible。
啥意思呢?每次运行代码时,必须保证bar与baz的执行顺序相同,得到的结果才相同。如果第一次先执行bar再执行baz,第二次先执行baz再执行bar,这两次的结果是不一样的。
这个现象在Numpy里似乎是无关紧要,不是啥问题,本来也是会按序执行的。但是这个东东在JAX里就不行了。。。
为啥呢?因为JAX是支持并行的!
这段代码如果想在JAX里复现,那就得强制按顺序执行,但是bar和baz两个方法没有依赖,在编译时,会被搞成并行执行的。这个就违背了JAX的第二点要求:parallelizable!

为了解决这个问题,在JAX中,不再使用全局状态,也就是不再设置全局的种子了。随机函数显式的、明确的消费随机状态,这个状态在JAX一般用“key”表示。
怎么理解呢?我们先看看这个key具体是什么:

from jax import random
key = random.PRNGKey(42)
print(key)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
------------------
[ 0 42]

key是一个shape为2的数组[0,42]。
其实这个key也就是numpy中的seed啦,只是在numpy中,seed只需要设置一次,但是在JAX里,只要用到了random的方法,需要明确的指定key,就是每次调用random的方法都要单独设置key,下文会见到。
random方法使用key,但是并不会改变它。同样,当一个random方法消费相同的key时,得到的结果也是一样的。

print(random.normal(key))
print(random.normal(key))
-0.18471184
-0.18471184

需要注意的是:
当不同的随机函数使用相同的key时,得到的结果是存在相关性的,这在一般情况不是我们想要的,我们希望要的东西是独立的。(PS:相关和独立就是概率里的那个概念)
不要重复使用key,不要重复使用key,不要重复使用key。。。

那么问题来了,一个key不能重复使用,但是又要求每个random方法都要明确的指定key,还要结果可重复,还让不让人玩儿了??

简单~!JAX在设计里肯定考虑这个了,不然早就混不下去了。怎么解决呢,就是把一个key给掰成几瓣!看代码:

from jax import random
key = random.PRNGKey(42)
print("old key", key)
new_key, subkey = random.split(key)
del key  # The old key is discarded -- we must never use it again.
normal_sample = random.normal(subkey)
print(r"    \---SPLIT --> new key   ", new_key)
print(r"             \--> new subkey", subkey, "--> normal", normal_sample)
del subkey  # The subkey is also discarded after use.
# Note: you don't actually need to `del` keys -- that's just for emphasis.
# Not reusing the same values is enough.
key = new_key  # If we wanted to do this again, we would use new_key as the key.
----------------------------------------------------------------------------------
old key [ 0 42]
    \---SPLIT --> new key    [2465931498 3679230171]
             \--> new subkey [255383827 267815257] --> normal 1.3694694

首先:split()是一个确定性函数(输入相同时,输出总是一样的),可以把一个key分裂成多个相互独立的key,并且还能满足伪随机的性质。上面的代码中,我们通过把一个老K,分裂成了两个小K。分裂之后的小K,还可以继续发展“下线”,继续向下分裂。
总之,只要保证不同的random方法使用不同的key就可以了。还有就是,做为母K,分裂之后,就不要再用了,是什么原因原文没有介绍,我也没查到。

无论是谁叫Key还是叫subkey都是不重要的,它们都是在相同状态下的伪随机数。上面那个例子,一般写成如下的形式,这时,老的key会被自动discarded,key被赋值分裂后的伪随机数。

key, subkey = random.split(key)

当然了,split不只是能分裂出两个子key,你想要几个都可以。

key, *forty_two_subkeys = random.split(key, num=43)

Numpy和JAX的random模型的另一个区别是顺序等价保证( the sequential equivalence guarantee),就是上面提到的执行顺序的问题。
JAX也是和Numpy一样,可以生成一个多维随机向量,但是JAX并不提供 the sequential equivalence guarantee,因为那样的话会影响在SIMD(单指令多数据结构)硬件上的向量化,也就是前面提到的第3点:vectorisable。
再看看前面一个例子在JAX中情况:

key = random.PRNGKey(42)
subkeys = random.split(key, 3)
sequence = np.stack([random.normal(subkey) for subkey in subkeys])
print("individually:", sequence)

key = random.PRNGKey(42)
print("all at once: ", random.normal(key, shape=(3,)))

individually: [-0.04838839  0.10796146 -1.2226542 ]
all at once:  [ 0.18693541 -1.2806507  -1.5593133 ]

我们没办法再得到两个相同的结果了。

同时需要注意的是,这里面我们把母key也用了,原文是说因为在别的地方没使用,所以没有违反只使用一次的原则。但是我个人觉得这个违反了分裂就丢弃的原则了啊!!

以上就是全文了,由于个人水平有限,理解的不一定正确,欢迎大家一起讨论~~

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值