215. 数组中的第K个最大元素 - 力扣(LeetCode)
排序
既然是排序后的第k个最大的元素,那自然就会想到排序了…
class Solution {
public:
int findKthLargest(vector<int>& nums, int k) {
sort(nums.begin(), nums.end());
return nums[nums.size()-k];
}
};
自己写快排:
class Solution {
public:
int partition(vector<int> &nums, int l, int r){
int n = rand()%(r-l+1) + l;
swap(nums[n], nums[r]);//随机交换元素,避免极端情况
int pivot = nums[r];
int i = l;
for(int j = l; j < r; ++j){
if(nums[j] < pivot)
swap(nums[i++], nums[j]);
}
swap(nums[i], nums[r]);
return i;
}
void q_sort(vector<int> &nums, int l, int r){
if(l >= r) return;
auto q = partition(nums, l, r);
q_sort(nums, l, q-1);
q_sort(nums, q+1, r);
}
int findKthLargest(vector<int>& nums, int k) {
q_sort(nums, 0, nums.size()-1);
return nums[nums.size()-k];
}
};
但是呢,排序的时间复杂度至少是O(nlogn),还有更好的办法吗?
堆/优先级队列
其实求top K是堆这种数据结构的典型应用了:
我们可以维护一个大小为K的小顶堆(堆顶元素为堆中最小值),当新进的元素比堆顶元素小时,就什么也不做,当比堆顶元素大时,就删除堆顶元素,将该元素入堆。当数组遍历完成之后,堆顶元素就是我们需要的第K个最大的元素,遍历数组的复杂度为O(n),每个元素入堆需要堆化,复杂度为O(logK),所以总的时间复杂度为O(nlogK)。
而堆可以用c++里面的优先级队列(默认是大顶堆,我们需要小顶堆)代替:
class Solution {
public:
int findKthLargest(vector<int>& nums, int k) {
priority_queue<int, vector<int>, greater<int>> q; //使用小顶堆
for(const auto &c : nums){
q.push(c);
if(q.size() > k) q.pop();
}
return q.top();
}
};
上面的代码可以进行优化,如果当前元素比堆顶元素小,就没有push的必要了:
class Solution {
public:
int findKthLargest(vector<int>& nums, int k) {
priority_queue<int, vector<int>, greater<int>> q;
for(int i = 0; i < k; ++i) q.push(nums[i]);
for(int i = k; i < nums.size(); ++i){
if(!q.empty() && nums[i] < q.top()) continue;
else{
q.pop();
q.push(nums[i]);
}
}
return q.top();
}
};
很多时候堆可能需要自己手动实现:
建堆、堆化等讲解数据结构与算法之美:28 | 堆和堆排序:为什么说堆排序没有快速排序快?
其实大顶堆也可以,建一个大顶堆,然后进行k-1次删除堆顶元素之后,堆顶元素就是第k大元素了,而且我们可以利用数组原地建堆(当然也可以选择新建一个数组一个元素一个元素插入这样建堆):
注意数组下标是从0开始的,可以直接操作,也可以将有效数据挪到下标为1处开始再操作,具体挪动方法:
nums.push_back(INT_MIN);
swap(*nums.begin(), *(nums.end()-1));
class Solution {
public:
void buildHeap(vector<int> &nums, int n){
for(int i = n/2; i >= 1; --i){ //从第一个非叶子节点开始
heapify(nums, n, i);
}
}
void heapify(vector<int> &nums, int n, int i){
while(true){
int maxPos = i;
if(2*i <= n && nums[i] < nums[2*i]) maxPos = 2*i;
if(2*i+1 <= n && nums[maxPos] < nums[2*i+1]) maxPos = 2*i+1;
if(i == maxPos) break;
swap(nums[i], nums[maxPos]);
i = maxPos;
}
}
int findKthLargest(vector<int>& nums, int k) {
nums.push_back(INT_MIN);
swap(*nums.begin(), *(nums.end()-1));
buildHeap(nums, nums.size()-1);
//进行k-1次删除堆顶元素操作
int heapsize = nums.size()-1;
for(int i = 1; i < k; ++i){
swap(nums[1], nums[nums.size()-i]);
--heapsize;
heapify(nums, heapsize, 1);
}
return nums[1];
}
};
不过增加一个元素让下标从1开始是不必要的,直接原地建堆一个元素都不用添加也是一样的思路,注意下标的处理就可以了:
class Solution {
public:
//完全二叉树节点数为n, 最后一个叶子节点下标为n-1
//最后一个非叶子节点下标为(n-1)/2
void buildHeap(vector<int> &nums, int n){
for(int i = (n-1)/2; i >= 0; --i){ //从第一个非叶子节点开始(父节点下标n-1)/2)
heapify(nums, n, i);
}
}
void heapify(vector<int> &nums, int n, int i){
while(true){
int maxPos = i;
if(2*i+1 <= n && nums[i] < nums[2*i+1]) maxPos = 2*i+1;
if(2*i+2 <= n && nums[maxPos] < nums[2*i+2]) maxPos = 2*i+2;
if(i == maxPos) break;
swap(nums[i], nums[maxPos]);
i = maxPos;
}
}
int findKthLargest(vector<int>& nums, int k) {
buildHeap(nums, nums.size()-1);
//进行k-1次删除堆顶元素操作
int heapsize = nums.size()-1;
for(int i = 1; i < k; ++i){
swap(nums[0], nums[nums.size()-i]);
--heapsize;
heapify(nums, heapsize, 0);
}
return nums[0];
}
};
快速排序partition
还有一种思路就是快速排序也可以在O(n)时间内查找第K大的元素,基于快速排序划分的思想:
选择数组区间 A[0…n-1]的最后一个元素 A[n-1]作为 pivot,对数组 A[0…n-1]原地分区,这样数组就分成了三部分,A[0…p-1]、A[p]、A[p+1…n-1]:
- 如果 p+1=K,那 A[p]就是要求解的元素;
- 如果 K>p+1, 说明第 K 大元素出现在 A[p+1…n-1]区间,我们再按照上面的思路递归地在 A[p+1…n-1]这个区间内查找。
- 如果 K<p+1,那我们就在 A[0…p-1]区间查找。
具体看代码:
class Solution {
public:
int partition(vector<int> &a, int low, int high){
auto pivot = a[high];
int i = low, j = low;
while(j < high){
if(a[j] > pivot){ //这儿选择将比pivot大的放在左边
swap(a[j], a[i]);
++i;
}
++j;
}
swap(a[i], a[high]);
return i;
}
int quickselect(vector<int> &a, int low, int high, int k){
auto q = partition(a, low, high);
if(q+1 == k) return a[q];
else return q+1 > k ? quickselect(a, low, q-1, k) : quickselect(a, q+1, high, k);
}
int findKthLargest(vector<int>& nums, int k) {
return quickselect(nums, 0, nums.size()-1, k);
}
};
但是,这样的代码虽然时间复杂度为O(n),空间复杂度为O(logn)(递归栈),执行效率(200ms)却不一定块,原因是我们每次都选择区间的最后一个元素作为pivot,在遇到特殊测试用例的时候,时间复杂度会变得很高(比如顺序数组与倒序数组,时间复杂度会上升到O(n^2),此时递归树画出来是链表)。这也是快速排序里经常强调的最坏情况,避免方法其实也比较简单,我们可以随机初始化 pivot 元素或者采用三数取中法等等。
随机初始化 pivot 元素
我们可以随机选择一个数组里的元素与数组末尾的元素进行交换,以达到随机pivot的目的:
class Solution {
public:
inline int randomPartition(vector<int>& a, int low, int high) {
int i = rand() % (high - low + 1) + low; //随机一个i,i在区间[low, high]
swap(a[i], a[high]);
return partition(a, low, high);
}
inline int partition(vector<int> &a, int low, int high){
auto pivot = a[high];
int i = low, j = low;
while(j < high){
if(a[j] > pivot){ //这儿选择将比pivot大的放在左边
swap(a[j], a[i]);
++i;
}
++j;
}
swap(a[i], a[high]);
return i;
}
int quickselect(vector<int> &a, int low, int high, int k){
auto q = randomPartition(a, low, high);
if(q+1 == k) return a[q];
else return q+1 > k ? quickselect(a, low, q-1, k) : quickselect(a, q+1, high, k);
}
int findKthLargest(vector<int>& nums, int k) {
srand(time(0));
return quickselect(nums, 0, nums.size()-1, k);
}
};
这样带返回值的快排其实不太容易理解,还是写成常见的版本吧:
class Solution {
public:
int idx, k;
int partition(vector<int> &nums, int l, int r){
int n = rand()%(r-l+1) + l;
swap(nums[n], nums[r]);//随机交换元素,避免极端情况
int pivot = nums[r];
int i = l;
for(int j = l; j < r; ++j){
if(nums[j] < pivot)
swap(nums[i++], nums[j]);
}
swap(nums[i], nums[r]);
return i;
}
void q_sort(vector<int> &nums, int l, int r){
//注意这儿改为l > r而不是 l >= r,是为了只有一个元素的时候
//也会判断下标和nums.size()-k的关系
if(l > r) return;
auto q = partition(nums, l, r);
if(q == nums.size()-k){
idx = q;
return;
}
q_sort(nums, l, q-1);
q_sort(nums, q+1, r);
}
int findKthLargest(vector<int>& nums, int k) {
this->k = k;
q_sort(nums, 0, nums.size()-1);
return nums[idx];
}
};
三数取中法
快速排序的多种实现方式_qq_32523711的博客-CSDN博客_快速排序的多种方法
class Solution {
public:
void getMid(vector<int> &a, int low, int high){
auto mid = (high + low) / 2;
if (a[low] > a[mid]) swap(a[low], a[mid]);
if (a[low] > a[high]) swap(a[low], a[high]);
if (a[mid] > a[high]) swap(a[mid], a[high]);
swap(a[mid], a[high-1]);
}
int partition(vector<int> &a, int low, int high){
getMid(a, low, high);
auto pivot = a[high-1];
//i, j一定从low开始,如果从low+1开始只有两个元素时会出错
int i = low, j = low;//i指向以处理区间右边界,j用来遍历元素
while (j < high-1){
if (a[j] < pivot) //小于的放左边
swap(a[i++], a[j]);//已处理区间多了一个元素,右边界增加1
++j;
}
swap(a[high-1], a[i]);
return i;
}
int quickselect(vector<int> &a, int low, int high, int k){
if(low = high) return a[low]; //区间只有一个元素的特殊情况
auto q = partition(a, low, high);
if(a.size()-q == k) return a[q]; //这儿换个方式处理
else return a.size()-q > k ? quickselect(a, q+1, high, k) : quickselect(a, low, q-1, k);
}
int findKthLargest(vector<int>& nums, int k) {
return quickselect(nums, 0, nums.size()-1, k);
}
};