对于n 个已排序的数组 A[1...n],其中项是其中间元素。
如果 n 是奇数,则中项是序列中第 (n+1)/2 个元素;
如果 n 是偶数,则存在两个中间元素,所处的位置分别是 n/2 和 n/2+1,在这种情况下,我们将选择第 n/2 个最小元素。
这样,综合两种情况,中项是第 ⌈n/2⌉ 最小元素。
寻找中项的一个直接的方法是对所有的元素排序并取出中间一个元素。
但是在一个具有 n 个元素的集合中,中项或通常意义上的第 k 小元素,能够在最优线性时间内找到,这个问题也称为选择问题。
其基本思想如下:假设递归算法中,在每个递归调用的划分步骤后,我们丢弃元素的一个固定部分并且对剩余的元素递归,则问题的规模以几何级数递减,也就是在每个调用过程中,问题的规模以一个常因子被减小。
为了具体性,我们假设不管处理什么样的对象,算法丢弃 1/3 并对剩余的 2/3 部分递归,那么在第二次调用中,元素的数目变为 2n/3 个,第三次调用中为 4n/9 个,第四次调用中为 8n/27 个,等等。现在,假定在每次调用中,算法对每个元素耗费的时间不超过一个常数,则耗费在处理所有元素上的全部时间产生一个几何级数
cn + (2/3)cn + (2/3)2cn + ... + (2/3)jcn + ... < 3cn
这正好是选择算法的工作。下面给出的寻找第 k 小元素的算法 Select 以同样的方法运作。
首先,如果元素个数小于预定义的阀值44(阀值可以自己定,这里先定义为44),则算法使用直接的方法计算第 k 小元素。
下一步把 n 个元素划分成 ⌊n/5⌋ 组,每组由5个元素组成,如果 n 不是5的倍数,则排除剩余的元素。
每组进行排序并取出它的中项即第三个元素。接着将这些中项序列中的中项元素记为 mm,它是通过递归计算得到的。
然后将原始数组划分成三个数组:A1, A2 和 A3,其中分别包含小于、等于和大于 mm 的元素。
最后,求出第 k 小的元素出现在三个数组中的哪一个,并根据测试结果,返回第 k 小的元素或者在 A1 或 A3 上递归。
过程 Select
输入 n 个元素的数组 A[1...n] 和整数 k,1 ≤ k ≤ n
输出 A 中的第 k 小元素
算法描述 select(A, low, high, k)
1. n ← high - low + 1
2. if n < 44 then 将 A 排序 return (A[k])
3. 令 q = ⌊n/5⌋。将 A 分成 q 组,每组5个元素。如果5不整除 n ,则排除剩余的元素。
4. 将 q 组中的每一组单独排序,找出中项。所有中项的集合为 M。
5. mm ← select(M, 1, q, ⌈q/2⌉) { mm 为中项集合的中项 }
6. 将 A[low...high] 分成三组
A1 = { a | a < mm }
A2 = { a | a = mm }
A3 = { a | a > mm }
7. case
|A1| ≥ k : return select(A1, 1, |A1|, k)
|A1| + |A2| ≥ k : return mm
|A1| + |A2| < k : return select(A3, 1, |A3|, k - |A1| - |A2|)
8. end case
算法中的数组是以数学角度来描述的,起始索引都是1,我们用程序来实现时需要注意一下
我这里用 MergeSort 来完成此算法中需要的排序操作,当然你也可以用任意其他的排序方法
public static int select(int[] sourceArray, int low, int high, int nthMinIndex)
{
if (nthMinIndex < 0 || low > high)
return -1;
int sourceLength = high - low + 1;
if (nthMinIndex >= sourceLength)
return -1;
if (sourceLength < 44)
{
mergeSort(sourceArray, low, high);
return sourceArray[nthMinIndex];
}
int middleArrayLength = 5;
int middleArrayQuantity = sourceLength / middleArrayLength;
int[] middleValueArray = new int[middleArrayLength];
for (int i = 0; i < middleArrayLength; i++)
{
mergeSort(sourceArray,
i * middleArrayQuantity, (i + 1) * middleArrayQuantity - 1);
int middleIndex = ((i * 2 + 1) * middleArrayQuantity - 1) / 2 +
(((i * 2 + 1) * middleArrayQuantity - 1) % 2 == 0 ? 0 : 1);
middleValueArray[i] = sourceArray[middleIndex];
}
int middleValue = select(middleValueArray, 0, middleArrayLength - 1,
(middleArrayLength - 1) / 2 + ((middleArrayLength - 1) % 2 == 0 ? 0 : 1));
List<Integer> lessThanMiddleValueList = new LinkedList<Integer>();
List<Integer> equalsWithMiddleValueList = new LinkedList<Integer>();
List<Integer> greaterThanMiddleValueList = new LinkedList<Integer>();
for (int i = 0; i < sourceArray.length; i++)
{
if (sourceArray[i] < middleValue)
{
lessThanMiddleValueList.add(sourceArray[i]);
}
else if (sourceArray[i] == middleValue)
{
equalsWithMiddleValueList.add(middleValue);
}
else
{
greaterThanMiddleValueList.add(sourceArray[i]);
}
}
Integer[] lessThanMiddleValueArray = new Integer[lessThanMiddleValueList.size()];
lessThanMiddleValueList.toArray(lessThanMiddleValueArray);
Integer[] greaterThanMiddleValueArray =
new Integer[greaterThanMiddleValueList.size()];
greaterThanMiddleValueList.toArray(greaterThanMiddleValueArray);
if (lessThanMiddleValueList.size() > nthMinIndex)
{
return select(ArrayUtils.toPrimitive(lessThanMiddleValueArray),
0, lessThanMiddleValueList.size() - 1, nthMinIndex);
}
else if (lessThanMiddleValueList.size() + equalsWithMiddleValueList.size()
> nthMinIndex)
{
return middleValue;
}
else
{
return select(ArrayUtils.toPrimitive(greaterThanMiddleValueArray),
0,
greaterThanMiddleValueList.size() - 1,
nthMinIndex
- lessThanMiddleValueList.size() - equalsWithMiddleValueList.size());
}
}