本文属于「征服LeetCode」系列文章之一,这一系列正式开始于2021/08/12。由于LeetCode上部分题目有锁,本系列将至少持续到刷完所有无锁题之日为止;由于LeetCode还在不断地创建新题,这个截止期限可能是永远。这一系列刷题文章中,不仅会讲解多种解题思路及其优化,还将用多种编程语言实现题解,涉及到通用解法时更会归纳总结出相应的算法模板。
为了方便在PC上运行调试、分享代码文件,我还将建立相关的仓库:https://github.com/memcpy0/LeetCode-Conquest。在这一仓库中,你不仅可以看到LeetCode原题链接、题解代码、题解文章链接、同类题目归纳、通用解法总结等,还可以看到原题出现频率和相关企业等重要信息。如果有其他优选题解,还可以一同分享给他人。
由于本系列文章的内容随时可能发生更新变动,欢迎关注和收藏征服LeetCode系列文章目录一文以作备忘。
You are given an array of positive integers w
where w[i]
describes the weight of i
th
index (0-indexed).
We need to call the function pickIndex()
which randomly returns an integer in the range [0, w.length - 1]
. pickIndex()
should return the integer proportional to its weight in the w
array. For example, for w = [1, 3]
, the probability of picking the index 0
is 1 / (1 + 3) = 0.25
(i.e 25%) while the probability of picking the index 1
is 3 / (1 + 3) = 0.75
(i.e 75%).
More formally, the probability of picking index i
is w[i] / sum(w)
.
Example 1:
Input
["Solution","pickIndex"]
[[[1]],[]]
Output
[null,0]
Explanation
Solution solution = new Solution([1]);
solution.pickIndex(); // return 0. Since there is only one single element on the array the only option is to return the first element.
Example 2:
Input
["Solution","pickIndex","pickIndex","pickIndex","pickIndex","pickIndex"]
[[[1,3]],[],[],[],[],[]]
Output
[null,1,1,1,1,0]
Explanation
Solution solution = new Solution([1, 3]);
solution.pickIndex(); // return 1. It's returning the second element (index = 1) that has probability of 3/4.
solution.pickIndex(); // return 1
solution.pickIndex(); // return 1
solution.pickIndex(); // return 1
solution.pickIndex(); // return 0. It's returning the first element (index = 0) that has probability of 1/4.
Since this is a randomization problem, multiple answers are allowed so the following outputs can be considered correct :
[null,1,1,1,1,0]
[null,1,1,1,1,1]
[null,1,1,1,0,0]
[null,1,1,1,0,1]
[null,1,0,1,0,0]
......
and so on.
Constraints:
1 <= w.length <= 10000
1 <= w[i] <= 10^5
pickIndex
will be called at most10000
times.
题意:给定一个正整数数组 w
,其中 w[i]
代表下标 i
的权重(下标从 0
开始),请写一个函数 pickIndex
,它可以随机地获取下标 i
,选取下标 i
的概率与 w[i]
成正比。也就是说,选取下标 i
的概率为 w[i] / sum(w)
。
解法1 前缀和+随机化+二分
这一道题乍一看可能懵了,该怎么做啊?不会啊?不过仔细一想,题意中说,w[i]
代表下标 i
的权重,选择 i
下标的概率与 w[i]
成正比为 w[i] / sum(w)
……这说明什么呢?说明在 pickIndex
调用总次数为
∑
i
=
0
w
.
l
e
n
g
t
h
−
1
w
[
i
]
\sum_{i=0}^{w.length-1} w[i]
∑i=0w.length−1w[i] 时,下标 i
的返回次数应该为权重值 w[i]
次。
于是我们要做的,就是随机生成一个分布符合权重的序列,其中随机数的产生可以使用语言自带的API,比如Java的 Math.random()
、C的 srand(unsigned(time(0))); rand();
、C++ <random>
设施。
由于
1
≤
w
[
i
]
≤
1
0
5
1 \le w[i] \le 10^5
1≤w[i]≤105 且
w
w
w 长度达到了
1
0
4
10^4
104 ,直接构造一个长度为 sum(w)
、每个 i
出现 w[i]
次的数组(以供随机抽取)会MLE。由此,可以使用前缀和数组作为权重分布序列,权重序列的基本单位为
1
1
1 ,代表 1 / sum(w)
的概率。整个算法的步骤是:
- 计算
w
数组的前缀和数组s
(一定是严格升序的),将其看做总长度为sum(w)
、基本单位为1
的数轴; - 接着使用随机函数产生 [ 1 , s u m ( w ) ] [1,\ sum(w)] [1, sum(w)] 范围内的随机数;
- 通过二分查找前缀和数组,即可找到分布位置,从而找到
w
数组中的原始下标值
以数组 w[] = {1, 3, 5}
为例,前缀和数组 s[] = {0, 1, 4, 9}
,随机生成 [1, 2, 3, 4, 5, 6, 7, 8, 9]
中的某个整数,其中生成 1
时对应的是 w[0]
,生成 2 ~ 4
时对应的是 w[1]
,生成 5 ~ 9
时对应的是 w[2]
,从而满足选择 i
下标的概率与 w[i]
成正比这一题目要求。
最后的代码如下所示,Solution
类的构造方法复杂度为
O
(
n
)
O(n)
O(n) ,pickIndex
方法的复杂度为
O
(
log
n
)
O(\log n)
O(logn) ,空间复杂度为
O
(
n
)
O(n)
O(n) :
//C++ version
unsigned seed = chrono::system_clock::now().time_since_epoch().count();
default_random_engine generator(seed);
class Solution {
private:
vector<int> s;
uniform_int_distribution<int> ud;
public:
Solution(vector<int>& w) {
int n = w.size();
s.resize(n + 1);
for (int i = 0; i < n; ++i) s[i + 1] = w[i] + s[i];
ud.param(uniform_int_distribution<>::param_type {1, s.back()});
}
int pickIndex() {
int target = ud(generator), l = 1, r = s.size();
while (l < r) { //找到前缀和数组s中第一个>=target的下标位置
int mid = l + (r - l) / 2;
if (s[mid] >= target) r = mid;
else l = mid + 1;
}
return l - 1; //从s的下标位置转换到w的下标位置
}
};
//执行用时:76 ms, 在所有 C++ 提交中击败了74.71% 的用户
//内存消耗:39.3 MB, 在所有 C++ 提交中击败了76.93% 的用户
Java版本的代码如下所示,不过这里生成的随机数范围为
[
0
,
s
u
m
(
w
)
)
[0, sum(w))
[0,sum(w)) 。以数组 w[] = {1, 3, 5}
为例,前缀和数组 s[] = {0, 1, 4, 9}
,随机生成 [1, 2, 3, 4, 5, 6, 7, 8, 9]
中的某个整数,其中生成 0
时对应的是 w[0]
,生成 1 ~ 3
时对应的是 w[1]
,生成 4 ~ 8
时对应的是 w[2]
。因此二分代码也不一样:
//Java version
class Solution {
private int[] s;
public Solution(int[] w) {
s = new int[w.length + 1];
for (int i = 0; i < w.length; ++i) s[i + 1] = w[i] + s[i];
}
public int pickIndex() {
int n = s.length, l = 1, r = n;
int target = new Random().nextInt(s[n - 1]); //生成一个随机的int值,介于[0,n)区间
while (l < r) { //找到第一个大于target的数
int mid = l + (r - l) / 2;
if (s[mid] > target) r = mid;
else l = mid + 1;
}
return l - 1;
}
}
做完这一题后,建议使用同样的方法做一下497. Random Point in Non-overlapping Rectangles。
解法2 模拟+桶轮询
这一做法主要是针对OJ,不建议用于实际工程。利用OJ对权重分布只做近似检查的特点,可以构造一个最短轮询序列(权重精度保留到小数点一位),并存储二元组 (i, cnt)
,代表下标 i
在最短轮询序列中出现次数为 cnt
。具体来说,步骤如下:
- 取出最小权重
minw
,计算出权重序列之和sum(w)
,于是最小权重代表的概率是minp = minw / sum(w)
; - 求出使
minp * k >= 1
的k
值,从而可以放大所有下标i
的概率w[i] / sum(w)
到大于等于1
; - 使用放大后的概率值作为下标
i
在最短轮询序列中出现的次数cnt
,存储这些二元组,加总所有下标i
对应的cnt
就能得到轮询序列总长度tot
; - 在
pickIndex
方法中,使用桶编号bid
和桶内编号inid
来对w.length
个桶进行轮询:- 访问当前桶
list.get(bid)
,如果桶内编号inid
没有超出当前桶的数量范围cnt
,就自增桶内编号++inid
,再返回当前桶的编号; - 否则
++bid
移动到下一个桶,重置桶内编号inid = 0
,递归调用pickIndex()
查看下一个桶; - 如果当前桶编号
bid
超过w.length
,就再从头开始。
- 访问当前桶
通过使用这一固定的轮询序列,好处是不需要使用随机函数,同时返回的连续序列在长度不短于最小段长度时,总是符合近似权重分布。实际代码如下:
//C++ version
class Solution {
private:
using bucket = pair<int, int>;
vector<bucket> seq;
int bid = 0, inid = 0, tot = 0; //桶编号,桶内编号,最短轮询序列长度
public:
Solution(vector<int>& w) {
int n = w.size();
double sum = 0, minw = DBL_MAX;
for (int i = 0; i < n; ++i) {
sum += w[i];
minw = fmin(minw, w[i]);
}
double minp = minw / sum;
int k = 1.0 / minp + 5;
for (int i = 0; i < n; ++i) {
int cnt = (int)(w[i] / sum * k);
seq.push_back(bucket{i, cnt});
tot += cnt;
}
}
int pickIndex() {
if (bid >= seq.size()) bid = inid = 0;
bucket b = seq[bid];
int id = b.first, cnt = b.second;
if (inid >= cnt) {
++bid;
inid = 0;
return pickIndex();
}
++inid;
return bid;
}
};
//执行用时:64 ms, 在所有 C++ 提交中击败了97.34% 的用户
//内存消耗:39.3 MB, 在所有 C++ 提交中击败了60.87% 的用户