根据这道题总结一下快速排序和堆排序,再根据这两种方法写这道题。
给定整数数组 nums
和整数 k
,请返回数组中第 k
个最大的元素。
请注意,你需要找的是数组排序后的第 k
个最大的元素,而不是第 k
个不同的元素。
你必须设计并实现时间复杂度为 O(n)
的算法解决此问题。
示例 1:
输入: [3,2,1,5,6,4]
, k = 2
输出: 5
示例 2:
输入: [3,2,3,1,2,4,5,5,6]
, k = 4
输出: 4
提示:
1 <= k <= nums.length <= 105
-104 <= nums[i] <= 104
我们首先给出快速排序的代码,快速排序的思路是先选取一个基准值,然后把小于基准值的放到基准值左边,把大于基准值的放到基准值右边,这样就会变成三部分(基准值左边部分、基准值、基准值右边部分),对基准值左右再递归进行这个步骤。代码分三部分:快速排序辅助分区部分、排序部分和主函数,分区部分就是把比基准值小的放左边,比基准值大的放右边,然后把基准值放中间,排序部分就是递归排序。
#include <iostream>
#include <vector>
#include <utility> // for std::swap
// 快速排序的辅助函数,进行分区
int partition(std::vector<int> &nums, int low, int high) {
// 选择最左侧的元素作为基准值(pivot)
int pivot = nums[low];
int i = low + 1; // i指针用来记录比基准值小的区域的最后一个元素的位置
int j = high; // j指针用来记录比基准值大的区域的第一个元素的位置
// 循环进行分区操作
while(true) {
// 从左向右找,找到大于等于基准值的元素
while (nums[i] < pivot) {
i++;
}
// 从右向左找,找到小于等于基准值的元素
while (nums[j] > pivot) {
j--;
}
if (i < j) {
std::swap(nums[i], nums[j]);
} else {
// 完成分区,左边全是小于等于基准值,右边全是大于等于基准值
break;
}
}
// 交换基准值到分区的中间
std::swap(nums[low], nums[j]);
// 返回基准值的最终位置
return i;
}
// 快速排序的递归函数
void quickSort(std::vector<int> &nums, int low, int high) {
if (low < high) {
// 分区操作
int pivotIndex = partition(nums, low, high);
// 对基准值左边的子序列进行快速排序
quickSort(nums, low, pivotIndex - 1);
// 对基准值右边的子序列进行快速排序
quickSort(nums, pivotIndex + 1, high);
}
}
int main() {
std::vector<int> nums = {10, 7, 8, 9, 1, 5};
int n = nums.size();
quickSort(nums, 0, n - 1);
for (int num : nums) {
std::cout << num << " ";
}
return 0;
}
运行结果(每一步分区的过程)为:
6 7 8 9 1 5 3 3 6 1 10
6 1 3 3 1 5 6 9 8 7 10
5 1 3 3 1 6 6 9 8 7 10
1 1 3 3 5 6 6 9 8 7 10
1 1 3 3 5 6 6 9 8 7 10
1 1 3 3 5 6 6 9 8 7 10
1 1 3 3 5 6 6 7 8 9 10
1 1 3 3 5 6 6 7 8 9 10
1 1 3 3 5 6 6 7 8 9 10
快速排序的时间复杂度是O(nlogn)
。
基于快速排序可以写出做这道题的快速选择方法的代码,与快速排序一样,需要先分区,之后确定了基准值最终所在的位置,然后不需要进行排序操作,只需要知道第k大的元素是在基准值左边还是右边,然后在那个分区找就可以了,也是递归来查找,这个就是在快排的过程中直接找到了,所以不需要进行完整的快排,因此复杂度变低:
#include <iostream>
#include <vector>
#include <utility> // for std::swap
// 快速排序的辅助函数,进行分区
int partition(std::vector<int> &nums, int low, int high) {
// 选择最左侧的元素作为基准值(pivot)
int pivot = nums[low];
int i = low + 1; // i指针用来记录比基准值小的区域的最后一个元素的位置
int j = high; // j指针用来记录比基准值大的区域的第一个元素的位置
// 循环进行分区操作
while(true) {
// 从左向右找,找到大于等于基准值的元素
while (nums[i] < pivot) {
i++;
}
// 从右向左找,找到小于等于基准值的元素
while (nums[j] > pivot) {
j--;
}
if (i < j) {
std::swap(nums[i], nums[j]);
} else {
// 完成分区,左边全是小于等于基准值,右边全是大于等于基准值
break;
}
}
// 交换基准值到分区的中间
std::swap(nums[low], nums[j]);
// 返回基准值的最终位置
return j;
}
// 快速排序的递归函数
int quickSelect(std::vector<int> &nums, int low, int high, int kIndex) {
if (low == high) {
// 当子数组只有一个元素时,返回该元素
return nums[low];
}
int pivotIndex = partition(nums, low, high);
if (kIndex <= pivotIndex) {
// 第k大的元素索引在左侧子数组中
return quickSelect(nums, low, pivotIndex, kIndex);
} else {
// 第k大的元素索引在右侧子数组中
return quickSelect(nums, pivotIndex + 1, high, kIndex);
}
}
int main() {
std::vector<int> nums = {10, 5, 3, 2, 1, 6, 8, 7};
int n = nums.size();
int k = 3;
// 第k大的元素的索引是k-1
int kIndex = k - 1;
int ans = quickSelect(nums, 0, n - 1, n - 1 - kIndex);
std::cout << "The ans is " << ans << std::endl;
return 0;
}
注意,当求第k
大的元素时,传入的是索引k-1
,当求第k
小的元素(第n-k+1
大)时,传入索引n-k
(即n-1-kIndex
)。这个方法时间复杂度是O(n)
。
下面来总结一下堆排序和这道题,我们给出堆排序的代码:
#include <iostream>
#include <vector>
#include <algorithm> // for std::swap
// 自上向下调整堆,保证堆的性质
void heapify(std::vector<int> &nums, int n, int i) {
int largest = i; // 初始时假设当前节点为最大值
int left = 2 * i + 1; // 左子节点
int right = 2 * i + 2; // 右子节点
// 如果左子节点存在且大于当前节点,更新最大值节点
if (left < n && nums[left] > nums[largest]) {
largest = left;
}
// 如果右子节点存在且大于当前节点,更新最大值节点
if (right < n && nums[right] > nums[largest]) {
largest = right;
}
// 如果最大值节点发生了变化,交换当前节点和最大值节点的值,并继续调整
if (largest != i) {
std::swap(nums[i], nums[largest]);
heapify(nums, n, largest);
}
}
// 堆排序
void heapSort(std::vector<int> &nums) {
int n = nums.size();
// 从最后一个非叶子节点开始建堆,即从 (n/2 - 1) 节点开始
for (int i = n / 2 - 1; i >= 0; i--) {
heapify(nums, n, i);
}
// 从最后一个元素开始,交换元素并进行调整堆操作
for (int i = n - 1; i > 0; i--) {
std::swap(nums[0], nums[i]); // 将当前堆的最大值放到数组末尾
heapify(nums, i, 0); // 调整堆,新的堆大小为 i
}
}
int main() {
std::vector<int> nums = {16, 10, 8, 7, 2, 3, 4, 1, 9, 14};
heapSort(nums);
std::cout << "Sorted array: ";
for (int num : nums) {
std::cout << num << " ";
}
return 0;
}
堆排序有三个重要部分:维护堆的性质,建堆,排序。以大根堆为例,这是一颗完全二叉树,父节点的值大于子节点的值,下标为i
的节点的父节点下标是(i - 1) / 2
(整数除法),下标为i
的节点的左孩子下标是i * 2 + 1
,右孩子下标是i * 2 + 2
,因此,假如有n
个元素,那么堆的最后一个非叶子节点的下标是n / 2 - 1
。
- 维护堆的性质,即为保证父节点值大于子节点值,从上而下调整,比如当前
i
节点不满足这个性质,那么交换i
节点和它的左右孩子中最大的那个,然后再判断子节点那里是否满足堆的性质(之所以需要这样是因为如果进行了交换,那么子节点那里可能会发生变化,比如3 6 5 2 4
这个情况,首先3
和6
进行了交换,变成了6 3 5 2 4
,那么3 2 4
那个部分(之前是6 2 4
)就需要再次进行交换)。 - 建堆,即从最后一个非叶子节点开始,自下而上维护堆的性质,直到根节点。
- 堆排序,将当前堆的最大值放到数组末尾,然后把它排除出去,再从根向下进行堆的维护,新的堆的大小为
n-1
,重复这个过程,直到只剩一个元素。
运行结果为:
create heap
16 14 8 9 10 3 4 1 7 2
sort heap for 9 nums 14 10 8 9 2 3 4 1 7 16
sort heap for 8 nums 10 9 8 7 2 3 4 1 14 16
sort heap for 7 nums 9 7 8 1 2 3 4 10 14 16
sort heap for 6 nums 8 7 4 1 2 3 9 10 14 16
sort heap for 5 nums 7 3 4 1 2 8 9 10 14 16
sort heap for 4 nums 4 3 2 1 7 8 9 10 14 16
sort heap for 3 nums 3 1 2 4 7 8 9 10 14 16
sort heap for 2 nums 2 1 3 4 7 8 9 10 14 16
sort heap for 1 nums 1 2 3 4 7 8 9 10 14 16
sorted heap
1 2 3 4 7 8 9 10 14 16
可以看到16, 10, 8, 7, 2, 3, 4, 1, 9, 14
经过建堆过程(自最后一个非叶子节点向上维护堆),变成了16 14 8 9 10 3 4 1 7 2
,然后需要进行堆排序,将16
和14
交换,然后不管16
了,这个时候它是最后一个元素,再从根向下维护堆,得到14 10 8 9 2 3 4 1 7 16
,然后再将14
和7
交换,进行相同的步骤,最后排序成功。
有了堆排序的基础,我们利用堆排序解决数组中的第K
个最大元素的问题,事实上,在堆排序取最大值的过程中,已经体现出来了,在第一次取16
,这就是第1
大的元素,第二次取14
就是第2
大的元素,那么我们想得到第k
大元素的值,只需要设置堆排序的停止条件为i > n - k
,然后这时候的nums[0]
(即根节点值)为第k
大的元素。如果我们想得到第’k’小的元素,那么就取第n-k+1
大的元素。
详细代码如下:
#include <iostream>
#include <vector>
#include <algorithm> // for std::swap
// 自上向下调整堆,保证堆的性质
void heapify(std::vector<int> &nums, int n, int i) {
int largest = i; // 初始时假设当前节点为最大值
int left = 2 * i + 1; // 左子节点
int right = 2 * i + 2; // 右子节点
// 如果左子节点存在且大于当前节点,更新最大值节点
if (left < n && nums[left] > nums[largest]) {
largest = left;
}
// 如果右子节点存在且大于当前节点,更新最大值节点
if (right < n && nums[right] > nums[largest]) {
largest = right;
}
// 如果最大值节点发生了变化,交换当前节点和最大值节点的值,并继续调整
if (largest != i) {
std::swap(nums[i], nums[largest]);
heapify(nums, n, largest);
}
}
// 堆排序取数
int heapSelect(std::vector<int> &nums, int k) {
int n = nums.size();
// 从最后一个非叶子节点开始建堆,即从 (n/2 - 1) 节点开始
for (int i = n / 2 - 1; i >= 0; i--) {
heapify(nums, n, i);
}
std::cout << "create heap" << std::endl;
for (int num : nums) {
std::cout << num << " ";
}
std::cout << "\n";
// 从最后一个元素开始,交换元素并进行调整堆操作
for (int i = n - 1; i > n - k; i--) {
std::swap(nums[0], nums[i]); // 将当前堆的最大值放到数组末尾
heapify(nums, i, 0); // 调整堆,新的堆大小为 i
std::cout << "sort heap for " << i << " nums" << " ";
for (int num : nums) {
std::cout << num << " ";
}
std::cout << "\n";
}
std::cout << "sorted heap" << std::endl;
for (int num : nums) {
std::cout << num << " ";
}
return nums[0];
}
int main() {
std::vector<int> nums = {16, 10, 8, 7, 2, 3, 4, 1, 9, 14};
int n = nums.size();
int k = 4;
int ans = heapSelect(nums, n - k + 1);
std::cout << "ans=" << ans << std::endl;
return 0;
}
运行结果:
create heap
16 14 8 9 10 3 4 1 7 2
sort heap for 9 nums 14 10 8 9 2 3 4 1 7 16
sort heap for 8 nums 10 9 8 7 2 3 4 1 14 16
sort heap for 7 nums 9 7 8 1 2 3 4 10 14 16
sort heap for 6 nums 8 7 4 1 2 3 9 10 14 16
sort heap for 5 nums 7 3 4 1 2 8 9 10 14 16
sort heap for 4 nums 4 3 2 1 7 8 9 10 14 16
sorted heap
4 3 2 1 7 8 9 10 14 16 ans=4
时间复杂度是O(nlogn)
,建堆的复杂度是O(n)
,删除堆顶元素的复杂度是O(klogn)
,所以总共的时间复杂度是O(n+klogn)=O(nlogn)
。