Reservoir Sampling 蓄水池抽样
1.问题描述
从包含n个元素的列表L(比如链表、输入流等)中随机选取k个样本。
- 其中n是一个很大或者未知的数量,也不能把n个元素都放到内存中,所以无法提前知道n。
- k个样本要随机选取。
- 只能遍历一次链表。
2.解决方案
1)分析问题
- k个样本要随机选取。一种可行的随机方案是n个元素中每个元素被选取的概率都相等,所以每个样本被选中的概率 = k/n
只能遍历一次链表L。假设当前已遍历的链表长度为i,如果对于任意的i>=k,满足每个样本被选取的概率是k/i;那么遍历至链表末尾时(i == n),每个样本被选取的概率就是 k/n.
问题的关键就在于:对于任意的i>=k,每个样本被选取的概率是k/i。
2)解决问题
- a) 首先定义一个大小为k的数组R(蓄水池),用于保存k个抽样结果。
- b) 然后将链表的前k个元素放入蓄水池R中,并令i := k。
此时,蓄水池中每个样本被选中的概率 = k/i = 1. - c) 接着从第k+1个元素开始遍历链表(i := k+1 to n),对于每一个链表元素 L[i]:
- i) 首先生成一个[1,i]范围内的随机数 j
- ii) 如果该随机数 j <= k,那么 令 R[j] = L[i]
- iii) 从以上过程可以看出:在访问链表元素L的元素i的时候,蓄水池R中的任意一个元素被L[i]替换的概率 = k/i * 1/k = 1/i
- iv) 也就是说:蓄水池中每个元素保留在蓄水池中(不被替换)的概率 = 1 - 1/i = (i-1)/i
- d) 归纳法证明:对任意的 i >= k,也满足蓄水池中每个样本被选中的概率 = k/i
- i) 当i==k时,显然链表中每个元素都会出现在蓄水池中,每个样本被选中的概率 = k/i = 1;
- ii) 假设当访问链表第i-1个元素时(i>=k+1),每个样本被选中的概率 = k/(i-1)
- iii) 并且已知在访问第i个元素时,蓄水池中每个元素保留下来的概率 = (i-1)/i
- iv) 那么当访问第i个元素时,蓄水池中每个元素被选中的概率 = k/(i-1)*(i-1)/i = k/i
- v) 当i==n时,蓄水池中每个元素被抽样选中的概率 = k/n. 证明完毕.
3.代码实现
注意:编程语言下标从0开始,所以实现和上述下标从1开始的证明过程有略微不同。
C++代码:
#include <iostream>
#include <cstdlib>
#include <ctime>
#include <vector>
#include <list>
#include <string>
std::vector<int> reservoir_sampling(std::list<int> L, int k){
std::vector<int> R;
if(L.empty() || k == 0) return R;
int idx = 0;
auto iter = L.begin();
for(idx = 0; idx < k && iter != L.end(); idx++ , iter++)
R.push_back(*iter);
if(R.size() < k) return std::vector<int>(); //empty vector means sampling failed.
std::srand(std::time(0));
for(idx = k; iter != L.end(); idx++ , iter++){
int j = std::rand()%(idx+1);
if(j <= k-1) R[j] = *iter;
}
return R;
}
int main(int argc, char* argv[]){
int N = 10000;
if(argc >= 2) N = std::stoi(argv[1]);
std::cout << "N = " << N << std::endl;
int K = 100;
if(argc >= 3) K = std::stoi(argv[2]);
std::cout << "K = " << K << std::endl;
std::list<int> L;
for(int i = 1; i <= N; i++) L.push_back(i);
std::vector<int> R = reservoir_sampling(L,K);
auto iter = R.begin();
int idx = 0;
long long sum = 0;
while(iter != R.end()){
std::cout << *iter << " ";
sum += *iter;
iter++;
idx++;
if(idx%10 == 0) std::cout << std::endl;
}
std::cout << std::endl;
std::cout << "SUM of sampling = " << sum << std::endl;
std::cout << "Average of sampling = " << (sum/K) << std::endl;
return 0;
}
4.抽样测试
1)编译
g++ reservoir_sampling.cpp -o reservoir_sampling
2)测试
# random sampling K=10000 numbers from [1,N=1000000]
./reservoir_sampling 1000000 10000
3)结果
# 1st run
SUM of sampling = 4982726215
Average of sampling = 498272
# 2ed run
SUM of sampling = 5015693839
Average of sampling = 501569
# 3rd run
SUM of sampling = 5033167581
Average of sampling = 503316
# 4th run
SUM of sampling = 4959622014
Average of sampling = 495962