通常,如果需要选取某个数组中第X小(大)的元素,我们可以使用简单选择排序遍历X遍数组,每一遍遍历可以得到一个最小的元素,这样算法的复杂度为 O(n2)。或者,我们可以使用更高效的快速排序算法,排序完成后,直接取第X位置的元素,这样,算法的复杂度为O(nlgn)。
这里,我们学习一种更高效的算法,利用这一算法,其期望的平均算法复杂度为线性复杂度。
该算法的核心思想是利用快速排序中每一部的分割,在将数组分割为2部分后,自然,前半部分的所有元素都不大于key,而后半部分的所有元素都不小于key。这样,优势在于:每次分割后,只需要关注一半的元素,而另一半不需要继续分割与排序。
可以看到,在selectXthMinValue函数中,只需要根据数组中元素的个数来判断应该关注key前面的部分或key后面的部分数组即可。
算法核心代码如下:
#ifndef _RANDOM_SELECT_H_
#define _RANDOM_SELECT_H_
class RandomSelect
{
public:
RandomSelect();
~RandomSelect();
//This function is select the Xth min value from pth to rth index in array a,
//return the index of the Xth min value
int selectXthMinValue(int *a, int xth, int length);
int selectXthMaxValue(int *a, int xth, int length);
private:
int selectXthMinValue(int* a, int p, int r, int xth);
int selectXthMaxValue(int *a, int p, int r, int xth);
//This function is partition the array a, after exec, the array will partition
//as 2 part, first part is p...q, second part is q+1 ... r, return the index value q
int partition(int* a, int p, int r);
bool validInputParams(int *a, int xth, int length);
};
#endif
#include "RandomSelect.h"
RandomSelect::RandomSelect()
{
}
RandomSelect::~RandomSelect()
{
}
int RandomSelect::selectXthMinValue(int *a, int xth, int length)
{
bool rs = validInputParams(a, xth, length);
if (!rs)
return -1;
else
{
return selectXthMinValue(a, 0, length - 1, xth);
}
}
int RandomSelect::selectXthMaxValue(int *a, int xth, int length)
{
bool rs = validInputParams(a, xth, length);
if (!rs)
return -1;
else
{
return selectXthMaxValue(a, 0, length - 1, xth);
}
}
bool RandomSelect::validInputParams(int *a, int xth, int length)
{
bool rs = true;
if (length <= 0)
{
return false;
}
if (xth > length)
{
return false;
}
return rs;
}
int RandomSelect::selectXthMinValue(int* a, int p, int r, int xth)
{
int q = partition(a, p, r);
int k = q - p + 1; //低区的元素个数,包括q
if (k == xth)
{
return q;
}
else if (k < xth)
{
return selectXthMinValue(a, q + 1, r, xth - k);
}
else
{
return selectXthMinValue(a, p, q - 1, xth);
}
}
int RandomSelect::selectXthMaxValue(int *a, int p, int r, int xth)
{
int q = partition(a, p, r);
int k = r - q + 1;
if (k == xth)
{
return q;
}
else if (k < xth)
{
return selectXthMaxValue(a, p, q - 1, xth - k);
}
else
{
return selectXthMaxValue(a, q + 1, r, xth);
}
}
int RandomSelect::partition(int* a, int p, int r)
{
int keyIndex = r;
int key = a[r];
while (p <= r)
{
while (p <= r)
{
if (a[p] <= key)
{
p++;
}
else //swap p pointer and key
{
int temp = a[p];
a[p] = key;
a[keyIndex] = temp;
keyIndex = p;
break;
}
}
while (p <= r)
{
if (a[r] >= key)
{
r--;
}
else //swap r pointer and key
{
int temp = a[r];
a[r] = key;
a[keyIndex] = temp;
keyIndex = r;
break;
}
}
}
return keyIndex;
}