题目链接
215. 数组中的第K个最大元素
题目描述
在未排序的数组中找到第k个最大的元素。请注意,你需要找的是数组排序后的第 k 个最大的元素,而不是第 k 个不同的元素。
样例
示例 1:
输入: [3,2,1,5,6,4] 和 k = 2
输出: 5
示例 2:
输入: [3,2,3,1,2,4,5,5,6] 和 k = 4
输出: 4
topK问题有两种主流做法
- 堆
- 减治(quick selection 算法)
算法1:堆
开一个小顶堆,用于维护枚举到 nums[i] 时的前 k 个最大元素
// 直接写 priority_queue<int> pq; 是大顶堆,小顶堆是下面的写法
priority_queue<int, vector<int>, greater<int> > pq;
先把前 k 个数压进去。然后枚举剩余的数,只要堆顶小于当前枚举的数,就先弹出堆顶在压入当前值,保持堆里是前 K 个最大元素。枚举结束后,堆顶就是答案。
时间复杂度稳定 O(NlogK)
代码(c++)
class Solution {
public:
int findKthLargest(vector<int>& nums, int k) {
int n = nums.size();
if(n == 1) return nums[0];
priority_queue<int, vector<int>, greater<int> > pq;
for(int i = 0; i < k; ++i)
pq.push(nums[i]);
for(int i = k; i < n; ++i)
if(pq.top() < nums[i])
{
pq.pop();
pq.push(nums[i]);
}
return pq.top();
}
};
算法2:快速选择算法(减治)
在子区间 [left, right] 中选择第 k 大的数时,完全照搬快排的划分算法:
- 选择一个枢轴(pivot),然后交换到 left 位置
int randIdx = rand() % (right - left + 1) + left; // 随机选择 pivot
swap(nums[randIdx], nums[left]);
2. 使用划分算法确定 pivot 的位置。大于 pivot 的元素移到左边,小于等于 pivot 的元素移到右边。
int pivot = nums[left];
int l = left, r = right;
while(l < r)
{
//...
}
// 一轮 partition 完成
一轮 partition 完成后,pivot 的位置为 l,此时考察 l 和 left + K 的关系,[left, right] 中第 K 个位置的下标是 left + K - 1:
l = left + K - 1: pivot 刚好在 [left, right] 第 k 个位置,找到答案了
l > left + K - 1: [left, l] 中有 l - left + 1 个数字,还要在 [l + 1, right] 中找 K - (l - left + 1) 个
l < left + K - 1: 在 [left, l - 1] 中继续找第 k 大
时间复杂度 平均O(N),最坏
快速选择算法有一个优化:BFPRT算法,也叫中位数的中位数算法,它进一步优化了 pivot 的选取方法,使得最坏时间复杂度也变为
快速选择算法每完成一轮 partition 将数据分成不均匀的两份后,只有其中一份对结果有影响,可以去掉一部分,数据规模就减少一部分,所以也叫减治算法。插入排序,DFS, BFS, 拓扑排序也隐含了减治的思想:每完成一轮,下一轮就可以少考虑一个数据。
代码(c++)
class Solution {
public:
int findKthLargest(vector<int>& nums, int k) {
int n = nums.size();
if(n == 1) return nums[0];
return partition(nums, k, 0, n - 1);
}
private:
int partition(vector<int>& nums, int k, int left, int right)
{
// 在 nums 的 [left .. right] 中找第 k 大
int randIdx = rand() % (right - left + 1) + left; // 随机选择 pivot
swap(nums[randIdx], nums[left]);
int pivot = nums[left];
int l = left, r = right;
while(l < r)
{
while(l < r && nums[r] <= pivot)
--r;
if(l < r)
{
nums[l] = nums[r];
++l;
}
while(l < r && nums[l] > pivot)
++l;
if(l < r)
{
nums[r] = nums[l];
--r;
}
}
nums[l] = pivot;
if(l == left + k - 1)
return nums[l];
else if(l < left + k - 1)
return partition(nums, k - (l - left + 1), l + 1, right);
else
return partition(nums, k, left, l - 1);
}
};
算法3: 归并树
直接套文章 [力扣315] 索引数组&CDQ分治&归并树 里的归并树的思路:已经根据原始数组建好归并树之后,可以
在归并树建树过程中可以顺便得到数组的最大值 right 和最小值 left,答案肯定在 [left, right]。
值域二分:每次从答案范围 [left, right] 中猜一个答案 mid = (left + right) / 2; 然后查询区间 [left, right] 中大于 mid 的个数 cnt
cnt = k - 1: mid 是答案,返回 mid
cnt > k - 1: mid 猜小了,left = mid + 1 继续猜
cnt < k - 1: mid + 1 肯定大了, right = mid 继续猜
时间复杂度
代码(c++)
struct MTNode
{
int start, end;
vector<int> data;
MTNode *left, *right;
MTNode(int s, int e, const vector<int>& nums, MTNode* l=nullptr, MTNode* r=nullptr)
:start(s),end(e),data(nums),left(l),right(r) {}
~MTNode()
{
if(left)
{
delete left;
left = nullptr;
}
if(right)
{
delete right;
right = nullptr;
}
}
};
class MergeTree
{
public:
MergeTree()
{
root = nullptr;
}
~MergeTree()
{
if(root)
{
delete root;
root = nullptr;
}
}
void build(int start, int end, const vector<int>& nums)
{
root = _build(start, end, nums);
}
int query(int i, int j, int k)
{
if(i > j) return 0;
int result = 0;
_query(root, i, j, k, result);
return result;
}
int get(int i)
{
return (root -> data)[i];
}
private:
MTNode *root;
void _query(MTNode* root, int i, int j, int k, int& result)
{
if(root -> start == i && root -> end == j)
{
auto pos = upper_bound((root -> data).begin(), (root -> data).end(), k);
result += (root -> data).end() - pos;
return;
}
int mid = root -> start + (root -> end - root -> start) / 2;
if(j <= mid)
{
_query(root -> left, i, j, k, result);
return;
}
if(i > mid)
{
_query(root -> right, i, j, k, result);
return;
}
_query(root -> left, i, mid, k, result);
_query(root -> right, mid + 1, j, k, result);
}
MTNode* _build(int start, int end, const vector<int>& nums)
{
if(start == end)
{
return new MTNode(start, end, vector<int>({nums[start]}));
}
int mid = start + (end - start) / 2;
MTNode *left = _build(start, mid, nums);
MTNode *right = _build(mid + 1, end, nums);
vector<int> merged((left -> data).size() + (right -> data).size());
merge((left -> data).begin(), (left -> data).end(), (right -> data).begin(), (right -> data).end(), merged.begin());
MTNode *cur = new MTNode(start, end, merged, left, right);
return cur;
}
};
class Solution_3 {
public:
int findKthLargest(vector<int>& nums, int k) {
int n = nums.size();
if(n == 1) return nums[0];
MergeTree mergetree;
mergetree.build(0, n - 1, nums);
int left = mergetree.get(0), right = mergetree.get(n - 1);
while(left < right)
{
int mid = left + (right - left) / 2;
int cnt = mergetree.query(0, n - 1, mid); // > mid 的元素个数
if(cnt == k - 1)
return mid;
else if(cnt > k - 1)
left = mid + 1;
else
right = mid;
}
return left;
}
};
引申: 区间第k大 & 划分树
划分树和归并树都是用线段树作为辅助的,其中划分树的建树是模拟快排过程,自顶向下建树,归并树是模拟归并排序过程,自底向上建树,两种方法含有分治思想。
划分树的基本思想就是对于某个节点对应的区间 [start, end],把它划分成两个子区间,左边区间的数小于等于右边区间的数,左子区间对应划分后的 [start, mid],右子区间对应划分后的 [mid + 1, end],其中 mid = (left + right) / 2。查找的时候通过记录进入左子树的数的个数,确定下一个查找区间,直到区间长度变成1(start = end, 叶子节点),就找到了。
划分树的节点定义
start, end:节点对应的区间
nums[0..end-start]:节点持有的数据,一共 end - start + 1 个,是原始数组排好序之后在 [start, end] 范围的数字,但这里 nums 并没按排好序的顺序存放,而是按原数组的顺序放的。它具体含了哪些数字来自父节点划分的结果
toleft[0..end-start+2]:记录进入左子树的数字的个数
toleft[i] := [0 .. i - 1] 这 i 个数中,进入左子树的数字的个数,有前缀和的思想
toleft[0] = 0
struct PTNode
{
int start, end;
vector<int> nums, toleft;
PTNode *left, *right;
PTNode(int s, int e, PTNode* l=nullptr, PTNode* r=nullptr)
:start(s),end(e),nums(vector<int>(end - start + 1)),toleft(vector<int>(end - start + 2)),left(l),right(r) {}
~PTNode()
{
if(left)
{
delete left;
left = nullptr;
}
if(right)
{
delete right;
right = nullptr;
}
}
};
划分树的建树
过程基本就是模拟快排过程,对于区间 [start, end] 取一个已经排过序的区间中位数,然后把小于中值的点放左边,大于的放右边,等于中位数的需要单独统计,使得进入左子树的数字个数为 [mid - left + 1],确保树是平衡的。划分树建树之前先求取并持有原数组排序后的数组 sorted,节点的中位数直接就取 sorted[mid]。
int median = sorted[mid];
只要当前节点不是叶子节点(start = end),就先建好两个子节点
PTNode *left = new PTNode(root -> start, mid);
PTNode *right = new PTNode(mid + 1, root -> end);
然后执行划分的算法流程,顺序枚举当前节点的所有数据,判断应该划分的方向,往对应的子节点塞进去。如果是塞进了左子树,toleft 加一。
划分树的查询
query(i, j, k) 查询区间 [i, j] 中第 k 大的数
首先确定 [start ..i - 1], [start..j] 中去往左子树的数字个数 tli, tlj。其中 start = i 的情况需要特判。tlj - tli 是 [i..j] 中去往左子树的元素个数,记 tl。
int tli = 0;
if(root -> start != i)
tli = (root -> toleft)[i - root -> start];
int tlj = (root -> toleft)[j - root -> start + 1];
int tl = tlj - tli;
看 k 和 tl 的关系决定子查询,类似于快速选择算法中考察 l 和 left + K 的关系。
tl >= k: 第 k 大在左子树,答案是 query(root -> left, new_i, new_j, k)
tl < k: 第 k 大在右子树,答案是 query(root -> right, new_i, new_j, k - tl)
new_i, new_j 需要画图推导
用划分树 AC 本题的代码
但是跑的效果很差,因为查询只有1次,需要查询次数较多时才能体现划分树的优势
查询次数为M时,时间复杂度 O(NlogN + MlogN),比M次快速选择的 O(MN) 和归并树的
struct PTNode
{
int start, end;
vector<int> nums, toleft;
PTNode *left, *right;
PTNode(int s, int e, PTNode* l=nullptr, PTNode* r=nullptr)
:start(s),end(e),nums(vector<int>(end - start + 1)),toleft(vector<int>(end - start + 2)),left(l),right(r) {}
~PTNode()
{
if(left)
{
delete left;
left = nullptr;
}
if(right)
{
delete right;
right = nullptr;
}
}
};
class PartitionTree
{
public:
PartitionTree()
{
root = nullptr;
sorted = vector<int>();
}
~PartitionTree()
{
if(root)
{
delete root;
root = nullptr;
}
}
void build(int start, int end, const vector<int>& nums)
{
sorted = nums;
sort(sorted.begin(), sorted.end());
root = new PTNode(start, end);
root -> nums = nums;
_build(root);
}
int query(int i, int j, int k)
{
if(i > j) return 0;
return _query(root, i, j, k);
}
private:
PTNode *root;
vector<int> sorted;
int _query(PTNode* root, int i, int j, int k)
{
if(root -> start == root -> end) return (root -> nums)[0];
int tli = 0;
if(root -> start != i)
tli = (root -> toleft)[i - root -> start];
int tlj = (root -> toleft)[j - root -> start + 1];
int tl = tlj - tli;
int new_i, new_j;
if(tl >= k)
{
// 第 k 大在左子
new_i = root -> start + tli;
new_j = new_i + tl - 1;
return _query(root -> left, new_i, new_j, k);
}
else
{
// 第 k 大在右子
int mid = root -> start + (root -> end - root -> start) / 2;
new_i = mid + 1 + i - (root -> start) - tli;
new_j = new_i + j - i - tl;
return _query(root -> right, new_i, new_j, k - tl);
}
}
void _build(PTNode* root)
{
if(root -> start == root -> end)
return;
int mid = root -> start + (root -> end - root -> start) / 2;
int median = sorted[mid];
PTNode *left = new PTNode(root -> start, mid);
PTNode *right = new PTNode(mid + 1, root -> end);
int n = (root -> nums).size();
int median_to_left = mid - root -> start + 1;
for(int i = 0; i < n; ++i)
{
if((root -> nums)[i] < median)
--median_to_left;
}
// 出循环后 median_to_left 为去往左子树中等于中位数的个数
int to_left = 0; // 去往左子树的个数
int idx_left = 0, idx_right = 0;
for(int i = 0; i < n; ++i)
{
int cur = (root -> nums)[i];
if(cur < median || ((cur == median) && median_to_left > 0))
{
(left -> nums)[idx_left] = cur;
++idx_left;
++to_left;
if(cur == median)
--median_to_left;
}
else
{
(right -> nums)[idx_right] = cur;
++idx_right;
}
(root -> toleft)[i + 1] = to_left;
}
_build(left);
_build(right);
root -> left = left;
root -> right = right;
}
};
class Solution_4 {
public:
int findKthLargest(vector<int>& nums, int k) {
int n = nums.size();
if(n == 1) return nums[0];
PartitionTree partitiontree;
partitiontree.build(0, n - 1, nums);
return partitiontree.query(0, n - 1, n + 1 - k);
}
};