由n个元素组成的序列A[1...n]的中项是该序列中第[n/2](向上取整)小的元素。最直接的求法是先对序列排序再取出该元素,这需要O(n log n)的时间。
选择算法用于找出序列中第k小的元素。该算法设置一个阈值:当元素个数小于该阈值时,直接排序并取出第k小的元素;否则,将n个元素分为[n/5](向下取整)组,每组5个元素,如果n不是5的倍数,则在求中项时排除剩余的元素。对每组排序并取出其中项,即第3个元素。接着递归地求出这些中项所组成序列的中项,记为mm。然后将A中的全部元素划分成三个数组A1、A2和A3,分别包含小于、等于和大于mm的元素。最后判断第k小的元素落在这三个数组中的哪一个:根据判断结果,算法或者直接返回mm作为第k小的元素,或者在A1或A3上递归。
算法:SELECT
输入:n个元素的数组A[1...n]和整数k,1<=k<=n
输出:A中的第k小元素
select(A, low, high, k)
p ← high - low + 1
if p < 44 then 将A[low...high]排序; return A[low+k-1]
令q=[p/5](向下取整)。将A[low...high]分成q组,每组5个元素。如果5不整除p,则在求中项时排除剩余元素
将q组中的每一组单独排序,找出中项。所有中项的集合为M
mm ← select(M, 1, q, [q/2](向上取整)) {mm为中项集合的中项}
将A[low...high]分成三组
    A1 = {a|a<mm}
    A2 = {a|a=mm}
    A3 = {a|a>mm}
case
    |A1|>=k: return select(A1, 1, |A1|, k)
    |A1|+|A2|>=k: return mm
    |A1|+|A2|<k: return select(A3, 1, |A3|, k-|A1|-|A2|)
end case
下面是C++实现:
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <iostream>
#include <stack>
#include <vector>
using std::stack;
using std::cout;
using std::endl;
// Partitions a[low..high] around the pivot a[low] (Lomuto scheme).
// On return the pivot sits at the returned index; every element before it
// within the range is <= pivot and every element after it is > pivot.
int Split(int * a, int low, int high) {
    const int pivot = a[low];
    int boundary = low;  // last index of the "<= pivot" prefix built so far
    for (int scan = low + 1; scan <= high; ++scan) {
        if (a[scan] > pivot) {
            continue;
        }
        ++boundary;
        if (boundary != scan) {
            const int held = a[boundary];
            a[boundary] = a[scan];
            a[scan] = held;
        }
    }
    // Move the pivot from the front of the range to its final position.
    const int first = a[low];
    a[low] = a[boundary];
    a[boundary] = first;
    return boundary;
}
// Sorts a[low..high] in ascending order with an iterative quicksort.
// Pending sub-ranges are kept on an explicit stack (pushed as low, then high)
// instead of recursing; each range is partitioned with Split.
void QuickSort(int * a, int low, int high) {
    if (high <= low) {
        return;  // empty or single-element range: nothing to do
    }
    std::stack<int> pending;
    pending.push(low);
    pending.push(high);
    while (!pending.empty()) {
        const int right = pending.top();
        pending.pop();
        const int left = pending.top();
        pending.pop();

        const int mid = Split(a, left, right);
        // Queue only the sides that still contain two or more elements.
        if (mid - 1 > left) {
            pending.push(left);
            pending.push(mid - 1);
        }
        if (mid + 1 < right) {
            pending.push(mid + 1);
            pending.push(right);
        }
    }
}
// Returns the k-th smallest (1-based) element of A[low..high] using the
// median-of-medians selection algorithm.
//
// Precondition: 1 <= k <= high - low + 1 (otherwise the access is out of range).
// Side effect: A[low..high] is reordered (groups of 5 are sorted in place), so
// pass a copy if the original order matters. Elements outside [low, high] are
// never touched.
int select(int * A, int low, int high, int k) {
    const int p = high - low + 1;
    // Small ranges are sorted directly. The pseudocode uses 44 as the
    // threshold; this demo uses 6 so small inputs still exercise the recursion.
    if (p < 6/*44*/) {
        std::sort(A + low, A + high + 1);
        return A[low + k - 1];  // k counts from the start of the range, not index 0
    }

    // Median of each full group of 5 inside the range. A trailing partial
    // group is skipped here but still takes part in the partition below.
    const int q = p / 5;
    std::vector<int> M(q);
    for (int i = 0; i < q; i++) {
        int * group = A + low + i * 5;
        std::sort(group, group + 5);
        M[i] = group[2];
    }
    // mm = the ceil(q/2)-th smallest median, found recursively.
    const int mm = select(M.data(), 0, q - 1, (q + 1) / 2);

    // Three-way partition of A[low..high] around mm. Elements equal to mm are
    // only counted: if the answer lies among them, it is mm itself.
    std::vector<int> A1;
    std::vector<int> A3;
    int count2 = 0;
    for (int i = low; i <= high; i++) {
        if (A[i] < mm) {
            A1.push_back(A[i]);
        } else if (A[i] == mm) {
            ++count2;
        } else {
            A3.push_back(A[i]);
        }
    }
    const int count1 = static_cast<int>(A1.size());
    const int count3 = static_cast<int>(A3.size());
    if (k <= count1) {
        return select(A1.data(), 0, count1 - 1, k);
    }
    if (k <= count1 + count2) {
        return mm;
    }
    return select(A3.data(), 0, count3 - 1, k - count1 - count2);
}
// Demo: prints a sample sequence and its 13th smallest element.
int main(void) {
    int a[] = {8, 33, 17, 51, 57, 49, 35, 11, 25, 37, 14, 3, 2, 13, 52, 12, 6, 29, 32, 54, 5, 16, 22, 23, 7};
    const int n = static_cast<int>(sizeof(a) / sizeof(a[0]));
    const int k = 13;

    // Print the input first: select() reorders the array it is given, so
    // printing afterwards would show a partially shuffled sequence.
    cout << "序列:\n";
    for (int i = 0; i < n; i++) {
        cout << a[i] << " ";
    }
    cout << endl;

    const int result = select(a, 0, n - 1, k);
    cout << "的第k小元素为:" << result << endl;
    getchar();  // wait for input before exiting
    return 0;
}
下面是Java版本:
package select;
import java.util.ArrayList;
import java.util.Arrays;
import sort.QuickSort;
public class SelectArray {
    // Private copy of the caller's values; select() only reorders temporary
    // arrays built from it.
    private ArrayList<Integer> array = new ArrayList<Integer>();

    public SelectArray(int [] array) {
        this.array.clear();
        for (int value : array) {
            this.array.add(value);
        }
    }

    /**
     * Returns the k-th smallest (1-based) element of A[low..high] using the
     * median-of-medians selection algorithm.
     *
     * @param A    array to search; A[low..high] may be reordered
     * @param low  first index of the range (inclusive)
     * @param high last index of the range (inclusive)
     * @param k    rank to select, 1 <= k <= high - low + 1
     * @return the k-th smallest element of the range
     */
    private int select(int [] A, int low, int high, int k) {
        int p = high - low + 1;
        // Small ranges are sorted directly (the pseudocode uses 44 as threshold).
        // Only [low, high] is sorted: A1/A3 below are over-allocated, and
        // sorting their unused zero-filled tail would corrupt the result.
        if (p < 6/*44*/) {
            Arrays.sort(A, low, high + 1);
            return A[low + k - 1];
        }
        int q = p / 5;
        int [] M = new int [q];
        for (int i = 0; i < q; i++) {
            int start = low + i * 5;
            // copyOfRange's upper bound is exclusive: start + 5 copies a full
            // group of 5 elements.
            int [] t = Arrays.copyOfRange(A, start, start + 5);
            Arrays.sort(t);
            M[i] = t[2];
        }
        // mm is the ceil(q/2)-th smallest median; (q + 1) / 2 is that ceiling
        // in integer arithmetic and is never 0 because q >= 1 here.
        int mm = select(M, 0, q - 1, (q + 1) / 2);
        int [] A1 = new int [p];
        int [] A3 = new int [p];
        int count1 = 0, count2 = 0, count3 = 0;
        for (int i = low; i <= high; i++) {
            if (A[i] < mm) {
                A1[count1++] = A[i];
            } else if (A[i] == mm) {
                count2++;
            } else {
                A3[count3++] = A[i];
            }
        }
        if (k <= count1) {
            return select(A1, 0, count1 - 1, k);
        }
        if (k <= count1 + count2) {
            return mm;
        }
        return select(A3, 0, count3 - 1, k - count1 - count2);
    }

    /**
     * Returns the k-th smallest (1-based) element of the stored values.
     * The stored values are copied first, so they are never reordered.
     */
    public int getSelectedElement(int k) {
        int [] A = new int [this.array.size()];
        for (int i = 0; i < A.length; i++) {
            A[i] = this.array.get(i);
        }
        return select(A, 0, A.length - 1, k);
    }

    /**
     * Demo: prints a sample sequence and its 13th smallest element.
     * @param args unused
     */
    public static void main(String[] args) {
        int a[] = {8, 33, 17, 51, 57, 49, 35, 11, 25, 37, 14, 3, 2, 13, 52, 12, 6, 29, 32, 54, 5, 16, 22, 23, 7};
        SelectArray sa = new SelectArray(a);
        System.out.println("序列:");
        for (int i = 0; i < a.length; i++) {
            System.out.print(a[i] + " ");
        }
        System.out.println();
        System.out.println("的第k小元素为:" + sa.getSelectedElement(13));
    }
}
Python版本如下:
#! /usr/bin/env python
# -*- coding:utf-8 -*-
from math import ceil
class SelectList:
    """Finds the k-th smallest element of a list with the median-of-medians
    selection algorithm. Works on both Python 2 and Python 3."""

    def __init__(self, l):
        # Keep a private copy so the caller's list is never reordered.
        self.array = list(l)

    def select(self, a, low, high, k):
        """Return the k-th smallest (1-based) element of a[low..high].

        Only the slice a[low:high + 1] is considered and `a` itself is not
        modified. Precondition: 1 <= k <= high - low + 1.
        """
        seq = a[low:high + 1]
        p = len(seq)
        # Small ranges are sorted directly (the pseudocode uses 44 here).
        if p < 6:
            return sorted(seq)[k - 1]
        # Median (3rd element) of each full group of 5; a trailing partial
        # group is skipped here but still takes part in the partition below.
        q = p // 5
        medians = [sorted(seq[i * 5:i * 5 + 5])[2] for i in range(q)]
        # mm = the ceil(q/2)-th smallest median, found recursively.
        mm = self.select(medians, 0, q - 1, (q + 1) // 2)
        a1 = [x for x in seq if x < mm]
        a3 = [x for x in seq if x > mm]
        count1 = len(a1)
        count2 = p - count1 - len(a3)  # number of elements equal to mm
        if k <= count1:
            return self.select(a1, 0, count1 - 1, k)
        if k <= count1 + count2:
            return mm
        return self.select(a3, 0, len(a3) - 1, k - count1 - count2)

    def getSelectedElement(self, k):
        """Return the k-th smallest (1-based) element of the stored list."""
        return self.select(self.array, 0, len(self.array) - 1, k)
if __name__ == '__main__':
    # Demo: print a sample sequence and its 13th smallest element.
    a = [8, 33, 17, 51, 57, 49, 35, 11, 25, 37, 14, 3, 2, 13, 52, 12, 6, 29, 32, 54, 5, 16, 22, 23, 7]
    sl = SelectList(a)
    # Single-argument print(...) behaves the same under Python 2 and Python 3,
    # unlike the Python 2-only print statement.
    print("序列:")
    print(" ".join(str(i) for i in a))
    print(" ".join(["的第k小元素为:", str(sl.getSelectedElement(13))]))