MT算法(Mersenne Twister算法)
(To 读者 :本篇文稿的主要内容是关于MT算法一些特点/优点,和一个常用的MT模型的结构,我把相关代码都附在文中了。整个行文是帮我同学做的算法总结,想着写都写了,不如发出来留个纪念,也可以帮到想了解MT算法的大家)
MT算法特性
Mersenne Twister算法,也就是梅森旋转算法,是一个伪随机数发生算法。由松本真和西村拓士在1997年开发,基于有限二进制字段上的矩阵线性递归。可以快速产生高质量的伪随机数,修正了古典随机数发生算法的很多缺陷,包括但不限于:
1.周期过短:
古典随机数生成算法通常有较短的周期,意味着生成的伪随机数在某个点后会重复。例如,线性同余生成器(LCG)的周期受限于其模数,而 Mersenne Twister 具有非常长的周期(2^{19937}−1),确保在极大的范围内不会出现重复伪随机数。
2.随机性不足:
早期的伪随机数生成器在统计性质上表现不佳,容易出现序列相关性,影响算法的均匀分布性。Mersenne Twister 使用了一种复杂的状态向量生成方式,通过多个独立的步骤和非线性变换改善了数列的随机性,使得它更接近于真正的随机数。
3.均匀分布性差:
某些传统算法生成的伪随机数在高维空间中表现不均匀,容易出现集聚现象,而 Mersenne Twister 通过更高维度的状态向量(624 维)和优化的生成过程,确保了伪随机数在高维空间内的良好均匀分布。
4.速度较慢:
传统的生成器在保证随机性的同时,往往计算速度较慢。Mersenne Twister 通过预先生成状态向量并进行批处理,能更快速地生成伪随机数,适合大规模数值模拟和高性能计算。
5.随机数质量不佳:
传统生成器在某些统计测试(如Spectral Test或Gap Test)中表现较差,可能产生偏差。Mersenne Twister 经过设计,能够通过一系列严格的统计测试,产生高质量的伪随机数。
常用MT算法变体
最为广泛使用Mersenne Twister的一种变体是MT19937,可以产生32位整数序列。具有以下的优点:
- 周期非常长,达到2−1。尽管如此长的周期并不必然意味着高质量的伪随机数,但短周期(比如许多旧版本软件包提供的2)确实会带来许多问题。
- 在1 ≤ k≤ 623的维度之间都可以均等分布(参见定义)。
- 除了在统计学意义上的不正确的随机数生成器以外,在所有伪随机数生成器法中是最快的(当时)
整个算法主要分为三个阶段:
第一阶段:获得基础的梅森旋转链;
第二阶段:对于旋转链进行旋转算法;
第三阶段:对于旋转算法所得的结果进行处理;
具体阶段图示如下所示:
算法实现的过程中,参数的选取取决于梅森素数,故此得名。
MT19937(伪代码表示)
初始化
首先将传入的seed赋给MT[0]作为初值,然后根据递推式:MT[i] = f × (MT[i-1] ⊕ (MT[i-1] >> (w-2))) + i递推求出梅森旋转链。伪代码如下:
// 由一个seed初始化随机数产生器
function seed_mt(int seed) {
index := n
MT[0] := seed
for i from 1 to (n - 1) {
MT[i] := lowest w bits of (f * (MT[i-1] xor (MT[i-1] >> (w-2))) + i)
}
}
对旋转链执行旋转算法
遍历旋转链,对每个MT[i],根据递推式:MT[i] = MT[i+m]⊕((upper_mask(MT[i]) || lower_mask(MT[i+1]))A)进行旋转链处理。
其中,“||”代表连接的意思,即组合MT[i]的高 w-r 位和MT[i+1]的低 r 位,设组合后的数字为x,则xA的运算规则为(x0是最低位):
x
A
=
{
x
≫
1
x
0
=
0
(
x
≫
1
)
⊕
a
x
0
=
1
\boldsymbol{x}A=\left\{\begin{array}{ll}\boldsymbol{x}\gg1&x_0=0\\(\boldsymbol{x}\gg1)\oplus\boldsymbol{a}&x_0=1\end{array}\right.
xA={x≫1(x≫1)⊕ax0=0x0=1
对应的伪代码如下:
lower_mask = (1 << r) - 1
upper_mask = !lower_mask
// 旋转算法处理旋转链
function twist() {
for i from 0 to (n-1) {
int x := (MT[i] & upper_mask)+ (MT[(i+1) mod n] & lower_mask)
int xA := x >> 1
if (x mod 2) != 0 {
// 最低位是1
xA := xA xor a
}
MT[i] := MT[(i + m) mod n] xor xA
}
index := 0
}
对旋转算法所得结果进行处理
设x是当前序列的下一个值,y是一个临时中间变量,z是算法的返回值。则处理过程如下:
y := x ⊕ ((x >> u) & d)
y := y ⊕ ((y << s) & b)
y := y ⊕ ((y << t) & c)
z := y ⊕ (y >> l)
伪代码如下:
// 从MT[index]中提取出一个经过处理的值
// 每输出n个数字要执行一次旋转算法,以保证随机性
function extract_number() {
if index >= n {
if index > n {
error "发生器尚未初始化"
}
twist()
}
int x := MT[index]
y := x xor ((x >> u) and d)
y := y xor ((y << s) and b)
y := y xor ((y << t) and c)
z := y xor (y >> l)
index := index + 1
return lowest w bits of (z)
}
由上述伪代码可以对应如下的C语言代码
MT19937(C语言实现)
#include <stdint.h>
// 定义MT19937-32的常数
enum
{
// 假定 W = 32 (此项省略)
N = 624,
M = 397,
R = 31,
A = 0x9908B0DF,
F = 1812433253,
U = 11,
// 假定 D = 0xFFFFFFFF (此项省略)
S = 7,
B = 0x9D2C5680,
T = 15,
C = 0xEFC60000,
L = 18,
MASK_LOWER = (1ull << R) - 1,
MASK_UPPER = (1ull << R)
};
static uint32_t mt[N];
static uint16_t index;
// 根据给定的seed初始化旋转链
void Initialize(const uint32_t seed)
{
uint32_t i;
mt[0] = seed;
for ( i = 1; i < N; i++ )
{
mt[i] = (F * (mt[i - 1] ^ (mt[i - 1] >> 30)) + i);
}
index = N;
}
static void Twist()
{
uint32_t i, x, xA;
for ( i = 0; i < N; i++ )
{
x = (mt[i] & MASK_UPPER) + (mt[(i + 1) % N] & MASK_LOWER);
xA = x >> 1;
if ( x & 0x1 )
{
xA ^= A;
}
mt[i] = mt[(i + M) % N] ^ xA;
}
index = 0;
}
// 产生一个32位随机数
uint32_t ExtractU32()
{
uint32_t y;
int i = index;
if ( index >= N )
{
Twist();
i = index;
}
y = mt[i];
index = i + 1;
y ^= (y >> U);
y ^= (y << S) & B;
y ^= (y << T) & C;
y ^= (y >> L);
return y;
}
MT19937(Python实现)
在讨论之前,引入MT19937-32的生成python代码:(此代码在 [0,2^32-1] 生成的伪随机数基本大致相同)
def _int32(x):
return int(0xFFFFFFFF & x)
class MT19937:
# 根据seed初始化624的state
def __init__(self, seed):
self.mt = [0] * 624
self.mt[0] = seed
self.mti = 0
for i in range(1, 624):
self.mt[i] = _int32(1812433253 * (self.mt[i - 1] ^ self.mt[i - 1] >> 30) + i)
# 提取伪随机数
def extract_number(self):
if self.mti == 0:
self.twist()
y = self.mt[self.mti]
y = y ^ y >> 11
y = y ^ y << 7 & 2636928640
y = y ^ y << 15 & 4022730752
y = y ^ y >> 18
self.mti = (self.mti + 1) % 624
return _int32(y)
# 对状态进行旋转
def twist(self):
for i in range(0, 624):
y = _int32((self.mt[i] & 0x80000000) + (self.mt[(i + 1) % 624] & 0x7fffffff))
self.mt[i] = (y >> 1) ^ self.mt[(i + 397) % 624]
if y % 2 != 0:
self.mt[i] = self.mt[i] ^ 0x9908b0df`
接下来,我们观察上面MT19937的代码,我们可以发现代码分为四个部分:
接下来,我们观察上面MT19937的代码,我们可以发现代码分为四个部分:
一、_int32(x)模块
返回一个32位的二进制代码。
二、init(self, seed)
首先,我们必须要知道seed在代码中是种子,意思是基于已知的seed生成624个state块(伪随机数通过对不同的state块进行变换求得),我们先将state的第一个数值定为seed,代码中的623个循环便是通过state间的变换求出求出剩下的state块。
三、extract_number(self)
MT19937算法通过此模块来得到不同的伪随机数。首先,我们先进行判断,如果此时self.mti指向第一个state,我们运行__init__(self, seed):得到623个state值,如果不是,则直接进入下面的伪随机数生成过程:用通过seed求得的state值进行代码中的变换求得并返回我们所需的伪随机数。
四、twist(self)
如果只有上面的块,那么只能求得624个不同的伪随机数,但是MT19937-32却可以求出2^32-1个不同的伪随机数便是因为这个模块。旋转模块基于上一次循环中我们已经使用过的624个state值,一一对应,通过原代码中的:
for i in range(0, 624):
y = _int32((self.mt[i] & 0x80000000) + (self.mt[(i + 1) % 624] & 0x7fffffff))
self.mt[i] = (y >> 1) ^ self.mt[(i + 397) % 624]
if y % 2 != 0:
self.mt[i] = self.mt[i] ^ 0x9908b0df
求得新的624个与上一次循环中不同的state值,并进入新的循环中。