有一个在大数据下很现实的例子:
“给出一个数据流,这个数据流的长度很大或者未知。并且对该数据流中数据只能访问一次。请写出一个随机选择算法,使得数据流中所有数据被选中的概率相等。”
解决这个问题既需要算法设计,又需要一些概率论知识,因此对于大多数人,起码包括我,这不是一个立刻就能想出答案的问题。
解决这个问题的算法叫蓄水池采样(Reservoir Sampling)算法。本篇博客介绍该算法的原理、证明和代码实现。
原理
介绍该算法之前,我们首先从最简单的例子出发:假设数据流只有一个数据。我们接收数据,发现数据流结束了,直接返回该数据,该数据返回的概率为1。看来很简单,那么我们试试难一点的情况:假设数据流里有两个数据。
我们读到了第一个数据,这次我们不能直接返回该数据,因为数据流没有结束。我们继续读取第二个数据,发现数据流结束了。因此我们只要保证以相同的概率返回第一个或者第二个数据就可以满足题目要求。因此我们生成一个0到1的随机数R,如果R小于0.5,我们就返回第一个数据,如果R大于0.5,返回第二个数据。
接着我们继续分析有三个数据的数据流的情况。为了方便,我们按顺序给流中的数据命名为1、2、3。我们陆续收到了数据1、2。和前面的例子一样,我们只能保存一个数据,所以必须淘汰1和2中的一个。应该如何淘汰呢?不妨和上面例子一样,我们按照二分之一的概率淘汰一个,例如我们淘汰了2。继续读取流中的数据3,发现数据流结束了,我们知道在长度为3的数据流中,如果返回数据3的概率为1/3,那么才有可能保证选择的正确性。也就是说,目前我们手里有1、3两个数据,我们通过一次随机选择,以1/3的概率留下数据3,以2/3的概率留下数据1。那么数据1被最终留下的概率是多少呢?
数据1被留下概率:(1/2)* (2/3) = 1/3
数据2被留下概率:(1/2)*(2/3) = 1/3
数据3被留下概率:1/3
这个方法可以满足题目要求,所有数据被留下返回的概率一样。
因此,循着这个思路,我们可以总结算法的过程:
假设需要采样的数量为 k。
首先构建一个可容纳 k 个元素的数组,将序列的前 k 个元素放入数组中。
然后对于第 j(j>k)个元素开始,以 k/j的概率来决定该元素是否被替换到数组中(数组中的k个元素被替换的概率是相同的)。 当遍历完所有元素之后,数组中剩下的元素即为所需采取的样本。
证明
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-5WNZOAtE-1662029573584)(/upload/2022/04/image-c985edac4ada4ddaa030accd45775db5.png)]
代码
class ReservoirSampling
{
public:
void SetStream(const std::vector<int32_t>& stream)
{
m_stream = stream;
}
std::vector<int32_t> Reservoir(int32_t k)
{
std::vector<int32_t>vec(k, 0);
for (int i = 0; i < k && i<m_stream.size(); ++i)
{
vec[i] = m_stream[i];
}
for (int i = k; i < m_stream.size(); ++i)
{
const auto index = rand() % (i + 1);
//如果这个数被选中,替换掉已有的。
//k/j的体现
if (index < k)
{
vec[index] = m_stream[i];
}
}
return vec;
}
private:
std::vector<int32_t>m_stream;
}
例题
398. 随机数索引
该问题可以理解为蓄水池大小为1
class Solution
{
public:
Solution(vector<int>& nums)
{
m_nums = nums;
}
int pick(int target)
{
int res{-1};
int count{0};
for(int i = 0; i<m_nums.size(); ++i)
{
if(m_nums[i] == target)
{
//count等同于蓄水池算法里的i
++count;
//rand()%count == 0
//相当于蓄水池算法里的index<k
if(rand()%count == 0)res = i;
}
}
return res;
}
private:
std::vector<int> m_nums;
};
/**
* Your Solution object will be instantiated and called as such:
* Solution* obj = new Solution(nums);
* int param_1 = obj->pick(target);
*/
参考
https://blog.csdn.net/anshuai_aw1/article/details/88750673