题目描述
给定整数数组 nums
和整数 k
,请返回数组中第 k
个最大的元素。
请注意,你需要找的是数组排序后的第 k
个最大的元素,而不是第 k
个不同的元素。
你必须设计并实现时间复杂度为 O(n)
的算法解决此问题。
示例 1:
输入: [3,2,1,5,6,4],
k = 2
输出: 5
示例 2:
输入: [3,2,3,1,2,4,5,5,6],
k = 4
输出: 4
思路
前言
-
约定:假设这里数组的长度为 n。
-
题目分析:本题希望我们返回数组排序之后的倒数第 k 个位置。
方法一:基于快速排序的选择方法
class Solution1 {//基于快速排序的选择方法 O(N) O(logn)
public:
int quickselect(vector<int>& nums, int l, int r, int k) {
if (l == r)//左右指针重合,输出第K大的元素
return nums[k];
int partition = nums[l], i = l - 1, j = r + 1;
while (i < j) {
do i++; while (nums[i] < partition);
do j--; while (nums[j] > partition);
if (i < j)
swap(nums[i], nums[j]);
}
if (k <= j)return quickselect(nums, l, j, k);
else return quickselect(nums, j + 1, r, k);
}
int findKthLargest(vector<int>& nums, int k) {
int n = nums.size();
return quickselect(nums, 0, n - 1, n - k);
}
};
复杂度分析
时间复杂度:O(n),如上文所述,证明过程可以参考「《算法导论》9.2:期望为线性的选择算法」。
空间复杂度:O(logn),递归使用栈空间的空间代价的期望为 O(logn)。
方法二:基于堆排序的选择方法
代码
class Solution {//基于堆排序的选择方法,O(n+klogn)=O(nlogn)
public:
void maxHeapify(vector<int>& a, int i, int heapSize) {//调整最大堆
int l = i * 2 + 1, r = i * 2 + 2, largest = i;
if (l < heapSize && a[l] > a[largest]) {//左孩子大于父节点
largest = l;
}
if (r < heapSize && a[r] > a[largest]) {//右孩子大于父节点
largest = r;
}
if (largest != i) {
swap(a[i], a[largest]);//交换元素
maxHeapify(a, largest, heapSize);//递归调整最大堆
}
}
void buildMaxHeap(vector<int>& a, int heapSize) {//建堆
for (int i = heapSize / 2; i >= 0; --i) {
maxHeapify(a, i, heapSize);
}
}
int findKthLargest(vector<int>& nums, int k) {
int heapSize = nums.size();
buildMaxHeap(nums, heapSize);
for (int i = nums.size() - 1; i >= nums.size() - k + 1; --i) {
swap(nums[0], nums[i]);//把最后一个元素放在堆顶
--heapSize;//删除根节点,删除K-1次,得到的堆顶元素值就是第K大的元素
maxHeapify(nums, 0, heapSize);//递归调整堆
}
return nums[0];
}
};
复杂度分析
时间复杂度:O(nlogn),建堆的时间代价是 O(n),删除的总代价是 O(klogn),因为 k<n,故渐进时间复杂为 O(n+klogn)=O(nlogn)。
空间复杂度:O(logn),即递归使用栈空间的空间代价。
整合代码
#include <iostream>
#include <vector>
#include <string>
#include <sstream>
using namespace std;
class Solution {//吴师兄
public:
int findKthLargest(vector<int>& nums, int k) {
// 执行快速排序操作,定位找到下标为 k - 1 的那个元素
vector<int> res = quickSort(nums, 0, nums.size() - 1, k - 1);
return res[k - 1];
}
// 函数传入待排序数组 nums
// 排序区间的左端点 left
// 排序区间的右端点 right
vector<int> quickSort(vector<int>& nums, int left, int right, int index) {
// 调用函数 partition,将 left 和 right 之间的元素划分为左右两部分
int mid = partition(nums, left, right);
// 如果 mid 下标恰巧为 index,那么找到了最小的 k 个数
if (mid == index) {
// 直接返回
return nums;
// 如果 mid 下标大于 index,那么说明需要在左侧元素中去切分
}
else if (mid > index) {
// 对 mid 左侧的元素进行快速排序
return quickSort(nums, left, mid - 1, index);
}
else {
// 对 mid 右侧的元素进行快速排序
return quickSort(nums, mid + 1, right, index);
}
}
int partition(vector<int>& nums, int left, int right) {
// 经典快速排序的写法
// 设置当前区间的第一个元素为基准元素
int pivot = nums[left];
// left 向右移动,right 向左移动,直到 left 和 right 指向同一元素为止
while (left < right) {
// 只有当遇到小于 pivot 的元素时,right 才停止移动
// 此时,right 指向了一个小于 pivot 的元素,这个元素不在它该在的位置上
while (left < right && nums[right] <= pivot) {
// 如果 right 指向的元素是大于 pivot 的,那么
// right 不断的向左移动
right--;
}
// 将此时的 nums[left] 赋值为 nums[right]
// 执行完这个操作,比 pivot 小的这个元素被移动到了左侧
nums[left] = nums[right];
// 只有当遇到大于 pivot left 才停止移动
// 此时,left 指向了一个大于 pivot 的元素,这个元素不在它该在的位置上
while (left < right && nums[left] >= pivot) {
// 如果 left 指向的元素是小于 pivot 的,那么
// left 不断的向右移动
left++;
}
// 将此时的 nums[right] 赋值为 nums[left]
// 执行完这个操作,比 pivot 大的这个元素被移动到了右侧
nums[right] = nums[left];
}
// 此时,left 和 right 相遇,那么需要将此时的元素设置为 pivot
// 这个时候,pivot 的左侧元素都小于它,右侧元素都大于它
nums[left] = pivot;
// 返回 left
return left;
}
};
class Solution1 {//基于快速排序的选择方法 O(N) O(logn)
public:
int quickselect(vector<int>& nums, int l, int r, int k) {
if (l == r)//左右指针重合,输出第K大的元素
return nums[k];
int partition = nums[l], i = l - 1, j = r + 1;
while (i < j) {
do i++; while (nums[i] < partition);
do j--; while (nums[j] > partition);
if (i < j)
swap(nums[i], nums[j]);
}
if (k <= j)return quickselect(nums, l, j, k);
else return quickselect(nums, j + 1, r, k);
}
int findKthLargest(vector<int>& nums, int k) {
int n = nums.size();
return quickselect(nums, 0, n - 1, n - k);
}
};
class Solution2 {//基于堆排序的选择方法,O(n+klogn)=O(nlogn)
public:
void maxHeapify(vector<int>& a, int i, int heapSize) {
int l = i * 2 + 1, r = i * 2 + 2, largest = i;
if (l < heapSize && a[l] > a[largest]) {
largest = l;
}
if (r < heapSize && a[r] > a[largest]) {
largest = r;
}
if (largest != i) {
swap(a[i], a[largest]);
maxHeapify(a, largest, heapSize);
}
}
void buildMaxHeap(vector<int>& a, int heapSize) {
for (int i = heapSize / 2; i >= 0; --i) {
maxHeapify(a, i, heapSize);
}
}
int findKthLargest(vector<int>& nums, int k) {
int heapSize = nums.size();
buildMaxHeap(nums, heapSize);
for (int i = nums.size() - 1; i >= nums.size() - k + 1; --i) {
swap(nums[0], nums[i]);
--heapSize;
maxHeapify(nums, 0, heapSize);
}
return nums[0];
}
};
vector<int> split(string params_str) {
vector<int> p;
while (params_str.find(",") != string::npos) {
int found = params_str.find(",");
p.push_back(stoi(params_str.substr(0, found)));
params_str = params_str.substr(found + 1);
}
p.push_back(stoi(params_str));
return p;
}
//输入:[3,2,1,5,6,4] 和 k = 2,输出:5
int main() {
string input;
getline(cin, input);
input = input.substr(1, input.length() - 2); // Remove square brackets
istringstream iss(input);
vector<int> nums;
string numStr;
while (getline(iss, numStr, ',')) {
nums.push_back(stoi(numStr));
}
int k;
cin >> k;
int ans;
Solution2 a;
ans = a.findKthLargest(nums, k);
cout << ans << endl;
}