寻找数组中第k小元素

n个已排序的序列A[1...n]的中项是这个序列的第[n/2](向上取整)小的元素。最直接的方法是对这个序列进行排序并取出该元素,这个方法需要O(n log n)的时间。

选择算法是找出序列中的第k小的元素,该算法会设置一个阈值,当元素个数小于该值时直接排序找出第k小元素。若不小于阈值,则将n个元素分为[n/5]组,每组5个元素,如果n不是5的倍数,则排除剩余的元素。每组进行排序并取出它们的中项即第3个元素。接着将这些中项序列中的中项元素记为mm,它是通过递归计算得到的。将A中的元素划分成三个数组:A1、A2和A3,其中分别包含小于、等于和大于mm的元素。最后求出第k小的元素出现在三个数组中的哪一个,并根据测试结果,算法或者返回滴k小的元素,或者在A1或A3上递归。

算法:SELECT

输入:n个元素的数组A[1...n]和整数k,1<=k<=n

输出:A中的第k小元素

select(A, low, high, k)

伪代码   收藏代码
  1. p ← high - low + 1  
  2. if p < 44 then 将A排序 return A[k]  
  3. 令q=[p/5](向下取整)。将A分成q组,每组5个元素。如果5不整除p,则排除剩余元素  
  4. 将q组中的每一组单独排序,找出中项。所有中项的集合为M  
  5. mm ← select(M, 1, q, [q/2](向上取整)) {mm为中项集合的中项}  
  6. 将A[low...high]分成三组  
  7. A1 = {a|a<mm}  
  8. A2 = {a|a=mm}  
  9. A3 = {a|a>mm}  
  10. case  
  11.     |A1|>=k: return select(A1, 1, |A1|, k)  
  12.     |A1|+|A2|>=k: return mm  
  13.     |A1|+|A2|<k: return select(A3, 1, |A3|, k-|A1|-|A3|)  
  14. end case  

 

 

下面是C++实现:

 

Cpp代码   收藏代码
  1. #include <iostream>  
  2. #include <stack>  
  3. #include <cmath>  
  4.   
  5. using std::stack;  
  6. using std::cout;  
  7. using std::endl;  
  8.   
  9. int Split(int * a, int low, int high) {  
  10.     int i = low;  
  11.     int x = a[low];  
  12.     for (int j = low+1; j <= high; j++) {  
  13.         if (a[j] <= x) {  
  14.             i ++;  
  15.             if (i != j) {  
  16.                 int temp = a[i];  
  17.                 a[i] = a[j];  
  18.                 a[j] = temp;  
  19.             }  
  20.         }  
  21.     }  
  22.     int temp = a[low];  
  23.     a[low] = a[i];  
  24.     a[i] = temp;  
  25.   
  26.     return i;  
  27. }  
  28.   
  29. void QuickSort(int * a, int low, int high) {  
  30.     if (low >= high) {  
  31.         return;  
  32.     }  
  33.   
  34.     stack<int> range;  
  35.     range.push(low);  
  36.     range.push(high);  
  37.     while(!range.empty()) {  
  38.         high = range.top();  
  39.         range.pop();  
  40.         low = range.top();  
  41.         range.pop();  
  42.   
  43.         int w = Split(a, low, high);  
  44.   
  45.         if (low < w-1) {  
  46.             range.push(low);  
  47.             range.push(w-1);  
  48.         }  
  49.         if (high > w+1) {  
  50.             range.push(w+1);  
  51.             range.push(high);  
  52.         }  
  53.     }  
  54. }  
  55.   
  56. //寻找第k小的元素,但会破坏原数组的顺序  
  57. int select(int * A, int low, int high, int k) {  
  58.     int result = 0;  
  59.     int p = high-low+1;  
  60.     if (p < 6/*44*/) {  
  61.         QuickSort(A, low, high);  
  62.         return A[k-1];  
  63.     }  
  64.     int q = p / 5;  
  65.     int * M = new int [q];  
  66.     for (int i = 0; i < q; i++) {  
  67.         QuickSort(A, i*5, i*5+4);  
  68.         M[i] = A[i*5+2];  
  69.     }  
  70.     int mm = select(M, 0, q-1, int(ceil(q/2.0)));  
  71.   
  72.     int * A1 = new int [p];  
  73.     int * A2 = new int [p];  
  74.     int * A3 = new int [p];  
  75.     int count1 = 0, count2 = 0, count3 = 0;  
  76.     for (int i = low; i <= high; i++) {  
  77.         if (A[i] < mm) {  
  78.             A1[count1++] = A[i];  
  79.         } else if (A[i] == mm) {  
  80.             A2[count2++] = A[i];  
  81.         } else {  
  82.             A3[count3++] = A[i];  
  83.         }  
  84.     }  
  85.     if (count1 >= k) {  
  86.         result = select(A1, 0, count1-1, k);  
  87.     } else if (count1+count2 >= k) {  
  88.         result = mm;  
  89.     } else if (count1+count2 < k) {  
  90.         result = select(A3, 0, count3-1, k-count1-count2);  
  91.     }  
  92.   
  93.     delete [] M;  
  94.     delete [] A1;  
  95.     delete [] A2;  
  96.     delete [] A3;  
  97.     return result;  
  98. }  
  99.   
  100. int main(void) {  
  101.     int a[] = {8, 33, 17, 51, 57, 49, 35, 11, 25, 37, 14, 3, 2, 13, 52, 12, 6, 29, 32, 54, 5, 16, 22, 23, 7};  
  102.     int result = select(a, 0, 24, 13);  
  103.   
  104.     cout << "序列:\n";  
  105.     for (int i = 0; i < 25; i++) {  
  106.         cout << a[i] << "  ";  
  107.     }  
  108.     cout << endl;  
  109.     cout << "的第k小元素为:" << result << endl;  
  110.   
  111.     getchar();  
  112.     return 0;  
  113. }  
 

下面是Java版本:

Java代码   收藏代码
  1. package select;  
  2.   
  3. import java.util.ArrayList;  
  4. import java.util.Arrays;  
  5.   
  6. import sort.QuickSort;  
  7.   
  8. public class SelectArray {  
  9.     private ArrayList<Integer> array = new ArrayList<Integer>();  
  10.       
  11.     public SelectArray(int [] array) {  
  12.         this.array.clear();  
  13.         for (int i = 0; i < array.length; i++) {  
  14.             this.array.add(array[i]);  
  15.         }  
  16.     }  
  17.       
  18.     private int select(int [] A, int low, int high, int k) {  
  19.         //QuickSort qs = null;  
  20.         int result = 0;    
  21.         int p = high-low+1;    
  22.         if (p < 6/*44*/) {    
  23.             A = new QuickSort(A).getSortedIntArray();    
  24.             return A[k-1];    
  25.         }    
  26.         int q = p / 5;    
  27.         int [] M = new int [q];    
  28.         for (int i = 0; i < q; i++) {  
  29.             int [] t = Arrays.copyOfRange(A, i*5, i*5+4);  
  30.             t = new QuickSort(t).getSortedIntArray();   
  31.             M[i] = t[2];    
  32.         }    
  33.         int mm = select(M, 0, q-1, (int)Math.floor(q/2.0));    
  34.         
  35.         int [] A1 = new int [p];    
  36.         int [] A2 = new int [p];    
  37.         int [] A3 = new int [p];    
  38.         int count1 = 0, count2 = 0, count3 = 0;    
  39.         for (int i = low; i <= high; i++) {    
  40.             if (A[i] < mm) {    
  41.                 A1[count1++] = A[i];    
  42.             } else if (A[i] == mm) {    
  43.                 A2[count2++] = A[i];    
  44.             } else {    
  45.                 A3[count3++] = A[i];    
  46.             }    
  47.         }    
  48.         if (count1 >= k) {    
  49.             result = select(A1, 0, count1-1, k);    
  50.         } else if (count1+count2 >= k) {    
  51.             result = mm;    
  52.         } else if (count1+count2 < k) {    
  53.             result = select(A3, 0, count3-1, k-count1-count2);    
  54.         }  
  55.         return result;    
  56.     }  
  57.       
  58.     public int getSelectedElement(int k) {  
  59.         int [] A = new int [this.array.size()];  
  60.         for (int i = 0; i < A.length; i++) {  
  61.             A[i] = this.array.get(i);   
  62.         }  
  63.         return select(A, 0, A.length-1, k);  
  64.     }  
  65.     /** 
  66.      * @param args 
  67.      */  
  68.     public static void main(String[] args) {  
  69.         // TODO Auto-generated method stub  
  70.         int a[] = {83317515749351125371432135212629325451622237};  
  71.         SelectArray sa = new SelectArray(a);  
  72.           
  73.         System.out.println("序列:");  
  74.         for (int i = 0; i < 25; i++) {    
  75.             System.out.print(a[i] + "  ");    
  76.         }   
  77.         System.out.println();  
  78.         System.out.println("的第k小元素为:" + sa.getSelectedElement(13));  
  79.     }  
  80. }  
 

Python版本如下:

Python代码   收藏代码
  1. #! /usr/bin/env python    
  2. # -*- coding:utf-8 -*-  
  3.   
  4. from math import ceil  
  5.   
  6. class SelectList:  
  7.     def __init__(self, l):  
  8.         self.array = list()  
  9.         for i in l:  
  10.             self.array.append(i)  
  11.   
  12.     def select(self, a, low, high, k):  
  13.         result = 0  
  14.         p = high-low + 1  
  15.           
  16.         if p < 6:  
  17.             a.sort()  
  18.             return a[k-1]  
  19.         q = p/5  
  20.         M = [0] * q  
  21.         for i in range(0, q):  
  22.             t = a[i*5:i*5+5]  
  23.             t.sort()  
  24.             M[i] = t[2]  
  25.         mm = self.select(M, 0, q-1, int(ceil(q/2.0)))  
  26.           
  27.         a1 = []   
  28.         a2 = []  
  29.         a3 = []  
  30.         count1 = 0  
  31.         count2 = 0  
  32.         count3 = 0  
  33.         for i in a:  
  34.             if i < mm:  
  35.                 a1.append(i)  
  36.                 count1 += 1  
  37.             elif i == mm:  
  38.                 a2.append(i)  
  39.                 count2 += 1  
  40.             else:  
  41.                 a3.append(i)  
  42.                 count3 += 1  
  43.           
  44.         if count1 >= k:  
  45.             result = self.select(a1, 0, count1-1, k)  
  46.         elif count1+count2 >= k:  
  47.             result = mm  
  48.         elif count1+count2 < k:  
  49.             result = self.select(a3, 0, count3-1, k-count1-count2)  
  50.         return result  
  51.   
  52.     def getSelectedElement(self, k):  
  53.         return self.select(self.array, 0, len(self.array)-1, k)  
  54.   
  55. if __name__ == '__main__':  
  56.     a = [83317515749351125371432135212629325451622237]  
  57.     sl = SelectList(a)  
  58.   
  59.     print "序列:"  
  60.     for i in a:  
  61.         print i,  
  62.     print  
  63.     print "的第k小元素为:", sl.getSelectedElement(13)  
 
  • 0
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值