import java.util.ArrayList;
import java.util.List;
import java.util.Random;
/**
 * Weighted (unequal-probability) random sampling without replacement.
 *
 * <p>Usage: pass the weights {@code rates} and the number of samples {@code k}; the result holds
 * the indices (positions in {@code rates}) of the drawn entries. For example, with rates
 * {@code [1, 2, 3, 4, 5]} and {@code k = 2}, if the entries with weight 1 and weight 3 are drawn,
 * the result contains {@code 0} and {@code 2}.
 *
 * <p>Implementation: the weights are laid out as an implicit binary tree in a list (slot 0 is
 * unused, the children of slot {@code i} are slots {@code 2i} and {@code 2i + 1}). Every node
 * stores its own weight and the sum of all weights in its sub-tree, so one draw walks a single
 * root-to-node path. Plain {@code double} arithmetic is used on purpose: the tiny rounding error
 * is accepted in exchange for speed.
 *
 * @author xutaoyang
 */
public class UnequalWithoutReplacementKRandom {
    /** Random source shared by every draw. */
    static Random rand = new Random();

    /**
     * Draws {@code k} distinct indices from {@code rates}; at each draw an entry is chosen with
     * probability proportional to its weight among the entries not drawn yet.
     *
     * <p>Entries whose weight is zero are never returned.
     *
     * @param rates
     *            non-negative, finite weights; must not be null or empty
     * @param k
     *            number of samples to draw; {@code 0 <= k <= rates.size()} and no larger than
     *            the number of strictly positive weights
     * @return the indices of the drawn entries in draw order; no index appears twice
     * @throws RuntimeException
     *             if {@code rates} is null or empty, or {@code k} exceeds {@code rates.size()}
     * @throws IllegalArgumentException
     *             if {@code k} is negative, a weight is negative / NaN / infinite, the weights
     *             sum to infinity, or fewer than {@code k} weights are strictly positive
     * @throws NullPointerException
     *             if {@code rates} contains a null element
     */
    public static List<Integer> randKWithoutReplacement(List<Double> rates, int k) {
        if (null == rates || rates.isEmpty()) {
            throw new RuntimeException("<<UnequalWithoutReplacementKRandom>> the rates list is null or empty");
        }
        // Drawing every entry (k == size) is legitimate; only k > size is an error.
        if (k > rates.size()) {
            throw new RuntimeException("<<UnequalWithoutReplacementKRandom>> k is bigger than rates' size");
        }
        if (k < 0) {
            throw new IllegalArgumentException("<<UnequalWithoutReplacementKRandom>> k is negative: " + k);
        }

        List<Node> nodes = new ArrayList<Node>(rates.size());
        int positiveCount = 0;
        for (int index = 0; index < rates.size(); index++) {
            Double rate = rates.get(index);
            if (rate == null) {
                throw new NullPointerException(
                        "<<UnequalWithoutReplacementKRandom>> rate at index " + index + " is null");
            }
            double weight = rate.doubleValue();
            if (Double.isNaN(weight) || Double.isInfinite(weight) || weight < 0) {
                throw new IllegalArgumentException("<<UnequalWithoutReplacementKRandom>> rate at index "
                        + index + " must be a finite, non-negative number but was " + weight);
            }
            if (weight > 0) {
                positiveCount++;
            }
            nodes.add(new Node(weight, index));
        }
        // A zero-weight entry can never be drawn, so it cannot help to satisfy k.
        if (k > positiveCount) {
            throw new IllegalArgumentException("<<UnequalWithoutReplacementKRandom>> k is " + k
                    + " but only " + positiveCount + " rates are positive");
        }

        List<Integer> result = new ArrayList<Integer>(k);
        if (k == 0) {
            return result;
        }
        List<Node> heap = buildHeap(nodes);
        if (Double.isInfinite(heap.get(1).totalWeight)) {
            throw new IllegalArgumentException(
                    "<<UnequalWithoutReplacementKRandom>> the sum of the rates overflows a double");
        }
        for (int index = 0; index < k; index++) {
            result.add(heapPop(heap));
        }
        return result;
    }

    /**
     * Lays the nodes out as a 1-based implicit binary tree and fills in every sub-tree sum.
     *
     * @param nodes
     *            the nodes in input order
     * @return a list whose slot 0 is {@code null} and whose slots 1..n hold the nodes
     */
    private static List<Node> buildHeap(List<Node> nodes) {
        List<Node> heap = new ArrayList<Node>(nodes.size() + 1);
        heap.add(null); // slot 0 is a placeholder so that the children of i are 2i and 2i + 1
        heap.addAll(nodes);
        // Children always sit at larger slots than their parent, so walking the slots backwards
        // guarantees both children are final before the parent is summed.
        for (int index = heap.size() - 1; index >= 1; index--) {
            refreshTotal(heap, index);
        }
        return heap;
    }

    /**
     * Draws one node with probability proportional to its weight, removes its weight from the
     * tree and returns its value (the original index).
     *
     * <p>Precondition: the root's sub-tree sum is strictly positive.
     *
     * @param heap
     *            the tree produced by {@link #buildHeap(List)}
     * @return the original index of the drawn entry
     */
    private static int heapPop(List<Node> heap) {
        // "gas" is the distance still to travel through the cumulative weights.
        double gas = heap.get(1).totalWeight * rand.nextDouble();
        int i = 1;
        while (true) {
            Node node = heap.get(i);
            int left = i << 1;
            int right = left + 1;
            double leftTotal = subtreeTotal(heap, left);
            double rightTotal = subtreeTotal(heap, right);
            // Stop here when the gas runs out inside this node's own weight, or when there is
            // nothing left below (rounding may leave a sliver of gas at the bottom of the tree).
            if (gas < node.weight || (leftTotal <= 0 && rightTotal <= 0)) {
                break;
            }
            gas -= node.weight;
            // Only ever step into a sub-tree that still holds weight; this keeps the walk inside
            // the list and away from entries that were already drawn.
            if (leftTotal > 0 && (gas < leftTotal || rightTotal <= 0)) {
                i = left;
            } else {
                gas -= leftTotal;
                i = right;
            }
        }

        Node chosen = heap.get(i);
        chosen.weight = 0;
        // Recompute the sums on the path to the root from the children instead of subtracting,
        // so an exhausted sub-tree sums to exactly 0 and rounding error does not accumulate.
        for (int slot = i; slot >= 1; slot >>= 1) {
            refreshTotal(heap, slot);
        }
        return chosen.value;
    }

    /** Returns the sub-tree sum stored at {@code slot}, or 0 when the slot is past the end. */
    private static double subtreeTotal(List<Node> heap, int slot) {
        return slot < heap.size() ? heap.get(slot).totalWeight : 0;
    }

    /** Sets the sub-tree sum at {@code slot} to its own weight plus both children's sums. */
    private static void refreshTotal(List<Node> heap, int slot) {
        Node node = heap.get(slot);
        node.totalWeight = node.weight + subtreeTotal(heap, slot << 1) + subtreeTotal(heap, (slot << 1) + 1);
    }

    /** One tree node: an entry's weight, its original index and its sub-tree's weight sum. */
    private static class Node {
        /** Weight of this entry; set to 0 once the entry has been drawn. */
        double weight;
        /** Index of this entry in the caller's rates list. */
        int value;
        /** Sum of the weights of this node and all of its descendants. */
        double totalWeight;

        public Node(double weight, int value) {
            this.weight = weight;
            this.value = value;
            this.totalWeight = weight;
        }
    }
}
// Reposted from: https://blog.51cto.com/1234554321/1612069