在《算法导论》第3版习题9.3-7中提到：设计一个O(n)时间的算法，对于给定的包含n个互异元素的集合S和正整数k（k ≤ n），确定S中最接近中位数的k个元素。
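步骤如下：
1: 用 select 求出数组 A 的中位数 nmid（即第 ⌈n/2⌉ 小的元素），其下标为 imid；
2: 计算 A 中每个数与中位数之差的绝对值，存入数组 dis，并拷贝一份到数组 dis_cpy（select 会原地重排数组，因此需要保留原顺序的 dis）；
3: 用 select 求出 dis_cpy 中第 k 小的数 nkmid；
4: 遍历数组 dis，先取出所有距离小于 nkmid 的数，再用距离等于 nkmid 的数补足，共取 k 个（中位数两侧可能有距离相等的两个数，若直接取前 k 个距离小于等于 nkmid 的数，可能漏掉更近的数）。
代码如下：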
/*
 * Lomuto partition of A[p..r] around the pivot value A[r].
 *
 * Rearranges A[p..r] in place and returns the pivot's final index q, such
 * that every element of A[p..q-1] is <= A[q] and every element of
 * A[q+1..r] is > A[q].
 */
int partition(int A[], int p, int r)
{
    const int pivot = A[r];
    int boundary = p;  /* next slot for an element <= pivot */

    for (int scan = p; scan < r; ++scan) {
        if (A[scan] > pivot) {
            continue;  /* larger elements stay on the right-hand side */
        }
        int held = A[boundary];
        A[boundary] = A[scan];
        A[scan] = held;
        ++boundary;
    }

    /* Move the pivot into the gap between the two halves. */
    int last = A[boundary];
    A[boundary] = A[r];
    A[r] = last;
    return boundary;
}
/*
 * Returns the k-th smallest element (k is 1-based) of A[p..r].
 *
 * Quickselect driven by partition(), written as a loop: each pass
 * partitions the current range and keeps only the side containing rank k.
 * A[p..r] is reordered in place.
 *
 * The pivot is always the range's last element (see partition()), so the
 * worst case is quadratic (e.g. already-sorted input). It is not the
 * worst-case-linear SELECT of CLRS section 9.3.
 */
int select(int A[], int p, int r, int k)
{
    for (;;) {
        assert(p <= r);
        assert(k <= r - p + 1);
        if (p == r) {
            return A[p];
        }

        const int pivotAt = partition(A, p, r);
        const int leftSize = pivotAt - p + 1;  /* elements in A[p..pivotAt] */

        if (k == leftSize) {
            return A[pivotAt];
        }
        if (k < leftSize) {
            r = pivotAt - 1;           /* answer lies left of the pivot */
        } else {
            k -= leftSize;             /* skip the left part and the pivot */
            p = pivotAt + 1;
        }
    }
}
/*
 * Returns the k elements of A[0..len-1] whose values are closest to the
 * lower median of A, i.e. the ceil(len/2)-th smallest element (the CLRS
 * definition of the median).
 *
 * The median itself counts as one of the k elements (its distance is 0),
 * so every 0 <= k <= len is accepted, matching the stated k <= n contract.
 * A is expected to hold distinct values, as in the exercise. A is not
 * modified: selection runs on a private copy, because select() reorders
 * its input in place.
 *
 * Returns a new[]-allocated array of exactly k ints in no particular
 * order. The caller owns it and must release it with delete[].
 *
 * Besides the two select() calls, all work is linear in len. Extra space
 * is one buffer of len ints plus the result.
 */
int* kth_select(int A[], int len, int k)
{
    assert(0 <= k && k <= len);
    if (k == 0) {
        return new int[0];  /* nothing to pick; also avoids selecting on an empty range */
    }

    /* Work on a copy so the caller's array keeps its order. */
    int *work = new int[len];
    memcpy(work, A, sizeof(int) * len);

    /* (len + 1) / 2 is the 1-based rank of the lower median for both odd
     * and even len; len / 2 would pick the wrong element when len is odd. */
    int nmid = select(work, 0, len - 1, (len + 1) / 2);

    /* Reuse the buffer for each element's distance to the median, then
     * find the k-th smallest distance.
     * NOTE(review): A[i] - nmid can overflow int for extreme values. */
    for (int i = 0; i < len; ++i) {
        work[i] = abs(A[i] - nmid);
    }
    int nkmid = select(work, 0, len - 1, k);
    delete[] work;  /* array form: work came from new[] */

    /* Since nkmid is the k-th smallest distance, fewer than k elements lie
     * strictly closer and at least k lie within nkmid. Distances can tie
     * (m - d and m + d), so first take every strictly closer element, then
     * fill the rest with elements at exactly nkmid. Otherwise a tie could
     * crowd out a closer element. */
    int *res = new int[k];
    int ik = 0;
    for (int i = 0; i < len && ik < k; ++i) {
        if (abs(A[i] - nmid) < nkmid) {
            res[ik++] = A[i];
        }
    }
    for (int i = 0; i < len && ik < k; ++i) {
        if (abs(A[i] - nmid) == nkmid) {
            res[ik++] = A[i];
        }
    }
    return res;
}