根据离散概率随机返回int值
对Java语言的Random库,似乎只能产生服从正态分布 N ( 0 , 1 ) N(0, 1) N(0,1)和均匀分布 U ( 0 , 1 ) U(0, 1) U(0,1)的随机数,那么如何按照特定的概率生成随机数呢?
先考虑最简单的情况
X X X | 0 0 0 | 1 1 1 |
---|---|---|
P ( X = x ) P(X = x) P(X=x) | p p p | 1 − p 1-p 1−p |
容易想到,对于服从均匀分布
U
(
0
,
1
)
U(0, 1)
U(0,1)的随机变量,产生的随机数落在
[
0
,
p
)
[0, p)
[0,p)的概率就为
p
p
p, 而落在
[
p
,
1
)
[p, 1)
[p,1)的概率为
1
−
p
1-p
1−p。
接下来,我们考虑稍微复杂的情况:
X X X | 0 0 0 | 1 1 1 | 2 2 2 |
---|---|---|---|
P ( X = x ) P(X=x) P(X=x) | p 0 p_0 p0 | p 1 p_1 p1 | p 2 p_2 p2 |
产生的随机数落在 [ 0 , p 0 ) [0, p_0) [0,p0)的概率就为 p 0 p_0 p0, 而落在 [ p 0 , p 0 + p 1 ) [p_0, p_0+p_1) [p0,p0+p1)的概率为 p 1 p_1 p1,落在 [ p 0 + p 1 , 1 ) [p_0+p_1, 1) [p0+p1,1),也就是落在 [ p 0 + p 1 , p 0 + p 1 + p 2 ) [p_0+p_1, p_0+p_1+p_2) [p0+p1,p0+p1+p2)的概率为 p 2 p_2 p2,
以此类推,我们可以归纳出以下结论:
X X X | 0 0 0 | 1 1 1 | 2 2 2 | . . . ... ... | i i i | . . . ... ... | n n n |
---|---|---|---|---|---|---|---|
P ( X = x ) P(X=x) P(X=x) | p 0 p_0 p0 | p 1 p_1 p1 | p 2 p_2 p2 | … | p i p_i pi | … | p n p_n pn |
产生的随机数落在 [ ∑ k = 0 i − 1 p k , ∑ k = 0 i p k ) [\sum_{k=0}^{i-1}p_k, \sum_{k=0}^{i}p_k) [∑k=0i−1pk,∑k=0ipk)的概率为 p i p_i pi,由此我们可以考虑对 p p p进行累加来求解问题。考虑如下代码:
public static int discrete(double[] p) {
double r = uniform();
double sum = 0.0;
for (int i = 0; i < p.length; i++) {
sum = sum + p[i];
if (sum >= r) return i;
}
}
其中uniform()自定义的函数,能够返回满足
U
(
0
,
1
)
U(0, 1)
U(0,1)的随机数。
显然这样的代码是不够健壮的,因为缺乏对“p中各元素的和是否为1”这个情况的判断,故我们添加并修改以下语句:
if (p == null) return ERROR;
for (int i = 0; i < p.length; i++) {
if (!(p[i] >= 0.0))
return ERROR;
sum += p[i];
}
if (sum != 1.0)
return ERROR;
由于计算过程一些舍入误差,可能导致以上代码的sum不为1,故更稳健的方法是使用以下方式:
double EPSILON = 1E-14;
if (p == null) return ERROR;
for (int i = 0; i < p.length; i++) {
if (!(p[i] >= 0.0))
return ERROR;
sum += p[i];
}
if (sum > 1.0 + EPSILON || sum < 1.0 - EPSILON) return ERROR;
回到以下部分
double r = uniform();
double sum = 0.0;
for (int i = 0; i < p.length; i++) {
sum = sum + p[i];
if (sum >= r) return i;
}
考虑一种极端的情况,若随机产生的r很接近1,而p的和又由于舍入误差刚好略小于1,这就有可能导致以上函数没有返回值,故我们可以添加一个循环来解决这个问题,即
while (true) {
double r = uniform();
sum = 0.0;
for (int i = 0; i < p.length; i++) {
sum = sum + p[i];
if (sum > r) return i;
}
}
完整实现代码如下1:
package myrandom;
import java.util.Random;
public class MyRandom {
private static Random rand = new Random();
public static double uniform() {
return rand.nextDouble();
}
public static int discrete(double[] probabilities) {
if (probabilities == null) throw new IllegalArgumentException("argument array is null");
double EPSILON = 1E-14;
double sum = 0.0;
for (int i = 0; i < probabilities.length; i++) {
if (!(probabilities[i] >= 0.0))
throw new IllegalArgumentException("array entry " + i + " must be nonnegative: " + probabilities[i]);
sum += probabilities[i];
}
if (sum > 1.0 + EPSILON || sum < 1.0 - EPSILON)
throw new IllegalArgumentException("sum of array entries does not approximately equal 1.0: " + sum);
while (true) {
double r = uniform();
sum = 0.0;
for (int i = 0; i < probabilities.length; i++) {
sum = sum + probabilities[i];
if (sum > r) return i;
}
}
}
public static void main(String[] args) {
double[] p= {0.2, 0.3, 0.5};
System.out.println(MyRandom.discrete(p));
}
}
本人学识尚浅,对于文中一些表达或推导难免存在不当或纰漏之处,恳请各位批评指出
部分代码参考自Robert Sedgewick和Kevin Wayne《算法(第4版)》StdRandom.java的实现,详细可见Java Algorithms and Clients ↩︎