含n个元素的序列A[1...n]的中项,是将该序列排序后位于第⌈n/2⌉位的元素,即序列中第⌈n/2⌉小的元素。最直接的方法是对这个序列进行排序并取出该元素,这个方法需要O(n log n)的时间。
选择算法是找出序列中的第k小的元素,该算法会设置一个阈值,当元素个数小于该值时直接排序找出第k小元素。若不小于阈值,则将n个元素分为⌊n/5⌋组,每组5个元素,如果n不是5的倍数,则排除剩余的元素。每组进行排序并取出它们的中项即第3个元素。接着将这些中项组成的序列的中项记为mm,它是通过递归计算得到的。将A中的元素划分成三个数组:A1、A2和A3,其中分别包含小于、等于和大于mm的元素。最后判断第k小的元素出现在三个数组中的哪一个,并根据判断结果,算法或者返回第k小的元素,或者在A1或A3上递归。
算法:SELECT
输入:n个元素的数组A[1...n]和整数k,1<=k<=n
输出:A中的第k小元素
select(A, low, high, k)
- p ← high - low + 1
- if p < 44 then 将A排序 return A[k]
- 令q=[p/5](向下取整)。将A分成q组,每组5个元素。如果5不整除p,则排除剩余元素
- 将q组中的每一组单独排序,找出中项。所有中项的集合为M
- mm ← select(M, 1, q, [q/2](向上取整)) {mm为中项集合的中项}
- 将A[low...high]分成三组
- A1 = {a|a<mm}
- A2 = {a|a=mm}
- A3 = {a|a>mm}
- case
- |A1|>=k: return select(A1, 1, |A1|, k)
- |A1|+|A2|>=k: return mm
- |A1|+|A2|<k: return select(A3, 1, |A3|, k-|A1|-|A2|)
- end case
下面是C++实现:
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <iostream>
#include <stack>
#include <stdexcept>
#include <vector>
using std::stack;
using std::cout;
using std::endl;
// Partitions a[low..high] around the pivot value a[low].
// On return the pivot sits at the returned index, every element before that
// index (within the range) is <= the pivot and every element after it is
// greater than the pivot.
int Split(int * a, int low, int high) {
    const int pivot = a[low];
    int boundary = low;  // last index of the "<= pivot" region
    int scan = low + 1;
    while (scan <= high) {
        if (a[scan] <= pivot) {
            ++boundary;
            // Exchanging an element with itself is a no-op, so no index check
            // is needed before the swap.
            const int held = a[boundary];
            a[boundary] = a[scan];
            a[scan] = held;
        }
        ++scan;
    }
    // a[low] was never touched by the loop, so it still holds the pivot:
    // move the boundary element to the front and drop the pivot in its place.
    a[low] = a[boundary];
    a[boundary] = pivot;
    return boundary;
}
- void QuickSort(int * a, int low, int high) {
- if (low >= high) {
- return;
- }
- stack<int> range;
- range.push(low);
- range.push(high);
- while(!range.empty()) {
- high = range.top();
- range.pop();
- low = range.top();
- range.pop();
- int w = Split(a, low, high);
- if (low < w-1) {
- range.push(low);
- range.push(w-1);
- }
- if (high > w+1) {
- range.push(w+1);
- range.push(high);
- }
- }
- }
// Returns the k-th smallest element (k is 1-based) of A[low..high] using the
// median-of-medians selection algorithm.
//
// The range A[low..high] is reordered in place (small ranges are fully sorted,
// larger ones have each complete group of five sorted), so the caller's
// original ordering is destroyed. Elements outside [low, high] are not touched.
//
// Throws std::out_of_range when the range is empty or k is not in
// [1, high-low+1].
int select(int * A, int low, int high, int k) {
    // Ranges shorter than this are sorted directly. The algorithm description
    // uses 44 here; this implementation uses 6.
    const int sortThreshold = 6;
    const int groupSize = 5;

    const int p = high - low + 1;
    if (p <= 0 || k < 1 || k > p) {
        throw std::out_of_range("select: k must satisfy 1 <= k <= high-low+1");
    }

    // All indexing below is relative to the start of the requested range, so
    // the function also works when low != 0.
    int * const first = A + low;

    if (p < sortThreshold) {
        std::sort(first, first + p);
        return first[k - 1];
    }

    // Sort each complete group of five and collect the group medians; the
    // p % 5 trailing elements take no part in choosing the pivot.
    const int q = p / groupSize;
    std::vector<int> medians(q);
    for (int g = 0; g < q; ++g) {
        int * const group = first + g * groupSize;
        std::sort(group, group + groupSize);
        medians[g] = group[groupSize / 2];
    }

    // mm is the ceil(q/2)-th smallest of the medians. q >= 1 here because
    // p >= sortThreshold > groupSize, so &medians[0] is valid.
    const int mm = select(&medians[0], 0, q - 1, (q + 1) / 2);

    // Three-way partition around mm. Elements equal to mm only need counting.
    std::vector<int> smaller;
    std::vector<int> larger;
    smaller.reserve(p);
    larger.reserve(p);
    int equalCount = 0;
    for (int i = 0; i < p; ++i) {
        const int value = first[i];
        if (value < mm) {
            smaller.push_back(value);
        } else if (value == mm) {
            ++equalCount;
        } else {
            larger.push_back(value);
        }
    }

    const int smallerCount = static_cast<int>(smaller.size());
    if (k <= smallerCount) {
        // k >= 1 guarantees smaller is non-empty on this path.
        return select(&smaller[0], 0, smallerCount - 1, k);
    }
    if (k <= smallerCount + equalCount) {
        return mm;
    }
    // k <= p guarantees larger is non-empty on this path.
    const int largerCount = static_cast<int>(larger.size());
    return select(&larger[0], 0, largerCount - 1, k - smallerCount - equalCount);
}
- int main(void) {
- int a[] = {8, 33, 17, 51, 57, 49, 35, 11, 25, 37, 14, 3, 2, 13, 52, 12, 6, 29, 32, 54, 5, 16, 22, 23, 7};
- int result = select(a, 0, 24, 13);
- cout << "序列:\n";
- for (int i = 0; i < 25; i++) {
- cout << a[i] << " ";
- }
- cout << endl;
- cout << "的第k小元素为:" << result << endl;
- getchar();
- return 0;
- }
下面是Java版本:
- package select;
- import java.util.ArrayList;
- import java.util.Arrays;
- import sort.QuickSort;
/**
 * Finds the k-th smallest element of an integer sequence with the
 * median-of-medians selection algorithm.
 *
 * The sequence handed to the constructor is copied, so selection never
 * modifies the caller's array.
 */
public class SelectArray {
    /** Ranges shorter than this are sorted directly instead of recursing. */
    private static final int SORT_THRESHOLD = 6;
    /** Number of elements per group when computing the medians. */
    private static final int GROUP_SIZE = 5;

    private ArrayList<Integer> array = new ArrayList<Integer>();

    /**
     * @param array the sequence to select from; its contents are copied
     */
    public SelectArray(int [] array) {
        for (int value : array) {
            this.array.add(value);
        }
    }

    /**
     * Returns the k-th smallest element of A[low..high] (both bounds inclusive).
     *
     * @param A    the array holding the range; it is not modified
     * @param low  index of the first element of the range
     * @param high index of the last element of the range
     * @param k    1-based rank, 1 <= k <= high-low+1
     * @return the k-th smallest element of the range
     * @throws IllegalArgumentException if the range is empty or k is out of range
     */
    private int select(int [] A, int low, int high, int k) {
        int p = high - low + 1;
        if (k < 1 || k > p) {
            throw new IllegalArgumentException(
                    "k must satisfy 1 <= k <= " + p + ", got " + k);
        }
        if (p < SORT_THRESHOLD) {
            // Sort only the requested range: the backing array may be longer
            // than the range (the partition buffers below are over-allocated).
            int [] sorted = Arrays.copyOfRange(A, low, high + 1);
            Arrays.sort(sorted);
            return sorted[k - 1];
        }
        int q = p / GROUP_SIZE;
        int [] M = new int [q];
        for (int i = 0; i < q; i++) {
            int from = low + i * GROUP_SIZE;
            // copyOfRange's upper bound is exclusive, so this is a full group.
            int [] t = Arrays.copyOfRange(A, from, from + GROUP_SIZE);
            Arrays.sort(t);
            M[i] = t[GROUP_SIZE / 2];
        }
        // ceil(q/2)-th smallest median; always >= 1 because q >= 1 here.
        int mm = select(M, 0, q - 1, (q + 1) / 2);
        int [] A1 = new int [p];
        int [] A3 = new int [p];
        int count1 = 0, count2 = 0, count3 = 0;
        for (int i = low; i <= high; i++) {
            if (A[i] < mm) {
                A1[count1++] = A[i];
            } else if (A[i] == mm) {
                count2++; // elements equal to mm only need counting
            } else {
                A3[count3++] = A[i];
            }
        }
        if (k <= count1) {
            return select(A1, 0, count1 - 1, k);
        }
        if (k <= count1 + count2) {
            return mm;
        }
        return select(A3, 0, count3 - 1, k - count1 - count2);
    }

    /**
     * Returns the k-th smallest element of the stored sequence.
     *
     * @param k 1-based rank, 1 <= k <= number of stored elements
     * @return the k-th smallest element
     * @throws IllegalArgumentException if k is out of range
     */
    public int getSelectedElement(int k) {
        int [] A = new int [this.array.size()];
        for (int i = 0; i < A.length; i++) {
            A[i] = this.array.get(i);
        }
        return select(A, 0, A.length - 1, k);
    }

    /**
     * Demo: prints a fixed sequence and its 13th smallest element.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        int a[] = {8, 33, 17, 51, 57, 49, 35, 11, 25, 37, 14, 3, 2, 13, 52, 12, 6, 29, 32, 54, 5, 16, 22, 23, 7};
        SelectArray sa = new SelectArray(a);
        System.out.println("序列:");
        for (int i = 0; i < a.length; i++) {
            System.out.print(a[i] + " ");
        }
        System.out.println();
        System.out.println("的第k小元素为:" + sa.getSelectedElement(13));
    }
}
Python版本如下:
- #! /usr/bin/env python
- # -*- coding:utf-8 -*-
- from math import ceil
class SelectList:
    """Finds the k-th smallest element of a sequence (median-of-medians).

    The sequence passed to the constructor is copied, and selection never
    reorders the list it is given.
    """

    # Ranges shorter than this are sorted directly instead of recursing.
    SORT_THRESHOLD = 6
    # Number of elements per group when computing the medians.
    GROUP_SIZE = 5

    def __init__(self, l):
        # Copy so later changes to the caller's sequence do not affect us.
        self.array = list(l)

    def select(self, a, low, high, k):
        """Return the k-th smallest element of a[low..high] (inclusive bounds).

        a    -- list holding the range; it is not modified
        low  -- index of the first element of the range
        high -- index of the last element of the range
        k    -- 1-based rank, 1 <= k <= high - low + 1

        Raises IndexError if the range does not lie inside a, and ValueError
        if the range is empty or k is out of range.
        """
        if low < 0 or high >= len(a):
            raise IndexError("range [%d, %d] is outside the list" % (low, high))
        p = high - low + 1
        if k < 1 or k > p:
            raise ValueError("k must satisfy 1 <= k <= %d, got %d" % (p, k))

        # Work on a copy of just the requested range.
        window = a[low:high + 1]
        if p < self.SORT_THRESHOLD:
            return sorted(window)[k - 1]

        # Floor division keeps q an int on both Python 2 and Python 3.
        q = p // self.GROUP_SIZE
        size = self.GROUP_SIZE
        M = [sorted(window[g * size:(g + 1) * size])[size // 2] for g in range(q)]
        # ceil(q/2)-th smallest of the group medians.
        mm = self.select(M, 0, q - 1, (q + 1) // 2)

        a1 = [x for x in window if x < mm]
        a3 = [x for x in window if x > mm]
        count1 = len(a1)
        count3 = len(a3)
        count2 = p - count1 - count3  # elements equal to mm

        if k <= count1:
            return self.select(a1, 0, count1 - 1, k)
        if k <= count1 + count2:
            return mm
        return self.select(a3, 0, count3 - 1, k - count1 - count2)

    def getSelectedElement(self, k):
        """Return the k-th smallest (1-based) element of the stored sequence."""
        return self.select(self.array, 0, len(self.array) - 1, k)
if __name__ == '__main__':
    # Demo: print a fixed sequence followed by its 13th smallest element.
    a = [8, 33, 17, 51, 57, 49, 35, 11, 25, 37, 14, 3, 2, 13, 52, 12, 6, 29, 32, 54, 5, 16, 22, 23, 7]
    sl = SelectList(a)
    # Each print(...) call takes a single argument so the same text is emitted
    # under both Python 2 (print statement) and Python 3 (print function).
    print("序列:")
    sequence_text = " ".join(str(i) for i in a)
    # The sequence and the result share one line, separated by single spaces.
    print(sequence_text + " " + "的第k小元素为:" + " " + str(sl.getSelectedElement(13)))