寻找数组第k大元素和上篇选择排序算法相似,只是分割之后舍弃另一半数据。
#include <cassert>
#include <iostream>
using namespace std;
//
// 交换 数组中两个元素
//
void swap(int *data, int a, int b)
{
int tmp = *(data + a);
*(data + a) = *(data + b);
*(data + b) = tmp;
}
//
// 对数组区间进行分割,返回支点位置
//
int partition(int *data, int begin, int end)
{
// 选择区间的第一个点作为 支点,暂存在变量pivot中
int pivot = *(data + begin);
// 接下来,扫描区间 从支点之后第一个元素到最后一个元素
int i = begin + 1;
int j = end;
while (true) {
// 从左向右扫描,如果元素值小于支点就向后移动
while (i <= j && *(data+i) < pivot) i++; // 注意:i <= j 而不是 i < j, 否则扫描区间不完整
// 从右向左扫描,如果元素值大于支点就向后移动
while (i <= j && *(data+j) >= pivot) j--; // 注意:i <= j 而不是 i < j, 否则扫描区间不完整
if (i > j) {
// 此时,扫描应该结束,i, j 位置是 j + 1 == i, j 在 i 前
break;
} else {
// 把 从左向右 找到的 大于支点的元素 和
// 从右向左 找到的 小于支点的元素 交换
swap(data, i, j);
}
}
// 扫描结束后,
// j 指向了 < pivot 区间最后一个元素 (区间元素个数是0时, 也成立),
// i 指向了 >= pivot 区间第一个元素 (区间元素个数是0时, 也成立),
// 因支点是取的是区间的第一个元素(位于 < pivot 区间一侧, 若支点取最后元素, 应取i作为分割点),
// 所以应该 把j 位置的值与pivot交换,j 也就是pivot点
swap(data, begin, j);
return j;
}
//
// 找第 nth 小元素 (nth 从 0 开始)
//
int _find_nth_min(int *data, int begin, int end, int nth)
{
int p = partition(data, begin, end);
if (p == nth) {
return data[p];
}
else if (p < nth) {
return _find_nth_min(data, p + 1, end, nth);
} else { // p >= nth
return _find_nth_min(data, begin, p - 1, nth);
}
}
//
// 找第 nth 小元素 wrapper 函数(nth 从 0 开始)
//
int find_nth_min(int *data, int len, int nth)
{
assert(0 <= nth && nth <= len - 1); // 确认范围正确
return _find_nth_min(data, 0, len - 1, nth);
}
//
// 找第 nth 大元素 (nth 从 0 开始)
//
int find_nth_max(int *data, int len, int nth)
{
return find_nth_min(data, len, len - 1 - nth);
}
void print_array(int *data, int len)
{
for (int i = 0; i < len - 1; i++)
{
cout << data[i] << ", ";
}
cout << data[len - 1] << endl;
}
int main()
{
int data[] { 3, 2, 5, 4, 9, 1 };
// 查找前,数组
print_array(data, sizeof(data)/sizeof(data[0]));
// 找第 nth 大元素,nth 从 0 开始
int nth { 2 }; // 指定第几大, 从 0 开始
int res = find_nth_max(data, sizeof(data)/sizeof(data[0]), nth);
cout << "Nth Max Element (starting from 0) : " << res << endl;
// 查找后,数组不需要是有序的
print_array(data, sizeof(data)/sizeof(data[0]));
return 0;
}