《算法导论》第3版9.3讲解了最坏情况为线性时间的选择算法
步骤如下
1: 将输入数组的n个元素划分为 n/5 组,每组5个元素,且至多只有一组由剩下的 n%5 个元素组成。
2: 寻找 n/5 组中每一组的中位数:首先对每组元素(至多为5个)进行插入排序,然后确定每组有序元素的中位数。
3: 对第2步中找出的 n/5 个中位数,递归调用 select 以找出其中位数 num(如果有偶数个中位数,为了方便,约定 num 是较小的中位数)
4: 利用修改过的partition版本,按中位数的中位数 num 对输入数组进行划分。让 count 比划分的低区中的元素数目多1,因此 num 是第 count 小的元素,并且有 n - count 个元素在划分的高区。
5: 如果 k == count,则返回 num。如果 k < count,则在低区递归调用 select 以找出第 k 小的元素。如果 k > count,则在高区递归查找第 k - count 小的元素。
代码如下
int select(int A[], int p, int r, int k);
int insert_sort(int A[], int len)
{
for (int j = 1; j < len; ++j)
{
int key = A[j];
// insert A[j] into the sorted sequence A[0..j-1]
int i = j - 1;
// 注意是 i >= 0 而不是 i > 0
while (i >= 0 && A[i] > key)
{
A[i + 1] = A[i];
--i;
}
A[i + 1] = key;
}
return A[len / 2];
}
int find_median(int A[], int p, int r)
{
int len = r - p + 1;
int *temp = new int[len / 5 + 1];
int start = p;
int end = p;
int j = 0;
for (int i = 0; i < len; ++i)
{
if (i % 5 == 0)
start = start + i;
if ((i + 1) % 5 == 0 || i == len - 1)
{
end = end + i;
int small_median = insert_sort(A, end - start + 1);
temp[j] = small_median;
++j;
}
}
int total_median = select(temp, 0, j - 1, (j - 1) / 2);
delete[] temp;
return total_median;
}
int partition(int A[], int p, int r, int num)
{
for (int i = p; i <= r; i++)
{
if (A[i] == num)
{
swap(A[i], A[r]);
break;
}
}
int x = A[r];
int i = p - 1;
for (int j = p; j < r; ++j)
{
if (A[j] <= x)
{
++i;
swap(A[i], A[j]);
}
}
swap(A[i + 1], A[r]);
return i + 1;
}
int select(int A[], int p, int r, int k)
{
assert(p <= r);
assert(k <= r - p + 1);
if (p == r)
return A[p];
int num = find_median(A, p, r);
int mid = partition(A, p, r, num);
int count = mid - p + 1;
if (k == count)
return A[mid];
else if (k < count)
return select(A, p, mid - 1, k);
else
return select(A, mid + 1, r, k - count);
}