有一个在大数据下很现实的例子:
“给出一个数据流,这个数据流的长度很大或者未知。并且对该数据流中数据只能访问一次。请写出一个随机选择算法,使得数据流中所有数据被选中的概率相等。”
解决这个问题既需要算法设计,又需要一些概率论知识,因此对于大多数人,起码包括我,这不是一个立刻就能想出答案的问题。
解决这个问题的算法叫蓄水池采样(Reservoir Sampling)算法。本篇博客介绍该算法的原理、证明和代码实现。
一、原理
介绍该算法之前,我们首先从最简单的例子出发:假设数据流只有一个数据。我们接收数据,发现数据流结束了,直接返回该数据,该数据返回的概率为1。看来很简单,那么我们试试难一点的情况:假设数据流里有两个数据。
我们读到了第一个数据,这次我们不能直接返回该数据,因为数据流没有结束。我们继续读取第二个数据,发现数据流结束了。因此我们只要保证以相同的概率返回第一个或者第二个数据就可以满足题目要求。因此我们生成一个0到1的随机数R,如果R小于0.5,我们就返回第一个数据,如果R大于0.5,返回第二个数据。
接着我们继续分析有三个数据的数据流的情况。为了方便,我们按顺序给流中的数据命名为1、2、3。我们陆续收到了数据1、2。和前面的例子一样,我们只能保存一个数据,所以必须淘汰1和2中的一个。应该如何淘汰呢?不妨和上面例子一样,我们按照二分之一的概率淘汰一个,例如我们淘汰了2。继续读取流中的数据3,发现数据流结束了,我们知道在长度为3的数据流中,如果返回数据3的概率为1/3,那么才有可能保证选择的正确性。也就是说,目前我们手里有1、3两个数据,我们通过一次随机选择,以1/3的概率留下数据3,以2/3的概率留下数据1。那么数据1被最终留下的概率是多少呢?
数据1被留下概率:(1/2)* (2/3) = 1/3
数据2被留下概率:(1/2)*(2/3) = 1/3
数据3被留下概率:1/3
这个方法可以满足题目要求,所有数据被留下返回的概率一样。
因此,循着这个思路,我们可以总结算法的过程:
假设需要采样的数量为 k k k。
首先构建一个可容纳 k k k 个元素的数组,将序列的前 k k k 个元素放入数组中。
然后对于第 j j j ( j > k j>k j>k)个元素开始,以 k j \frac{k}{j} jk 的概率来决定该元素是否被替换到数组中(数组中的 k k k个元素被替换的概率是相同的)。 当遍历完所有元素之后,数组中剩下的元素即为所需采取的样本。
二、证明
对于第
i
i
i 个数(
i
≤
k
i≤k
i≤k)。在
k
k
k 步之前,被选中的概率为 1。当走到第
k
+
1
k+1
k+1 步时,被
k
+
1
k+1
k+1 个元素替换的概率 = 第
k
+
1
k+1
k+1 个元素被选中的概率 *
i
i
i 被选中替换的概率,即为
k
k
+
1
×
1
k
=
1
k
+
1
\frac{k}{k + 1} \times \frac{1}{k} = \frac{1}{k + 1}
k+1k×k1=k+11。则不被第
k
+
1
k+1
k+1个元素替换的概率为
1
−
1
k
+
1
=
k
k
+
1
1 - \frac{1}{k + 1} = \frac{k}{k + 1}
1−k+11=k+1k。依次类推,不被
k
+
2
k+2
k+2 个元素替换的概率为
1
−
k
k
+
2
×
1
k
=
k
+
1
k
+
2
1 - \frac{k}{k + 2} \times \frac{1}{k} = \frac{k + 1}{k + 2}
1−k+2k×k1=k+2k+1。则运行到第
n
n
n 步时,第
i
i
i 个数仍保留的概率 = 被选中的概率 * 不被替换的概率,即:
1
×
k
k
+
1
×
k
+
1
k
+
2
×
k
+
2
k
+
3
×
…
×
n
−
1
n
=
k
n
1 \times \frac{k}{k + 1} \times \frac{k + 1}{k + 2} \times \frac{k + 2}{k + 3} \times … \times \frac{n - 1}{n} = \frac{k}{n}
1×k+1k×k+2k+1×k+3k+2×…×nn−1=nk
对于第
j
j
j 个数(
j
>
k
j>k
j>k)。我们知道,在第
j
j
j 步被选中的概率为
k
j
\frac{k}{j}
jk。不被
j
+
1
j+1
j+1 个元素替换的概率为
1
−
k
j
+
1
×
1
k
=
j
j
+
1
1 - \frac{k}{j + 1} \times \frac{1}{k} = \frac{j}{j + 1}
1−j+1k×k1=j+1j。则运行到第
n
n
n 步时,被保留的概率 = 被选中的概率 * 不被替换的概率,即:
k
j
×
j
j
+
1
×
j
+
1
j
+
2
×
j
+
2
j
+
3
×
.
.
.
×
n
−
1
n
=
k
n
\frac{k}{j} \times \frac{j}{j + 1} \times \frac{j + 1}{j + 2} \times \frac{j + 2}{j + 3} \times ... \times \frac{n - 1}{n} = \frac{k}{n}
jk×j+1j×j+2j+1×j+3j+2×...×nn−1=nk
所以对于其中每个元素,被保留的概率都为 k n \frac{k}{n} nk.
三、代码
来自参考文献【1】
public class ReservoirSamplingTest {
private int[] pool; // 所有数据
private final int N = 100000; // 数据规模
private Random random = new Random();
@Before
public void setUp() throws Exception {
// 初始化
pool = new int[N];
for (int i = 0; i < N; i++) {
pool[i] = i;
}
}
private int[] sampling(int K) {
int[] result = new int[K];
for (int i = 0; i < K; i++) { // 前 K 个元素直接放入数组中
result[i] = pool[i];
}
for (int i = K; i < N; i++) { // K + 1 个元素开始进行概率采样
int r = random.nextInt(i + 1);
// 这里其实就是k/j的体现
if (r < K) {
result[r] = pool[i];
}
}
return result;
}
@Test
public void test() throws Exception {
for (int i : sampling(100)) {
System.out.println(i);
}
}
}