寻找第k小元素

n个已排序的序列A[1...n]的中项是这个序列的第[n/2](向上取整)小的元素。最直接的方法是对这个序列进行排序并取出该元素,这个方法需要O(n log n)的时间。

选择算法是找出序列中的第k小的元素,该算法会设置一个阈值,当元素个数小于该值时直接排序找出第k小元素。若不小于阈值,则将n个元素分为[n/5]组,每组5个元素,如果n不是5的倍数,则排除剩余的元素。每组进行排序并取出它们的中项即第3个元素。接着将这些中项序列中的中项元素记为mm,它是通过递归计算得到的。将A中的元素划分成三个数组:A1、A2和A3,其中分别包含小于、等于和大于mm的元素。最后求出第k小的元素出现在三个数组中的哪一个,并根据测试结果,算法或者返回滴k小的元素,或者在A1或A3上递归。

算法:SELECT

输入:n个元素的数组A[1...n]和整数k,1<=k<=n

输出:A中的第k小元素

select(A, low, high, k)

 

p ← high - low + 1
if p < 44 then 将A排序 return A[k]
令q=[p/5](向下取整)。将A分成q组,每组5个元素。如果5不整除p,则排除剩余元素
将q组中的每一组单独排序,找出中项。所有中项的集合为M
mm ← select(M, 1, q, [q/2](向上取整)) {mm为中项集合的中项}
将A[low...high]分成三组
A1 = {a|a<mm}
A2 = {a|a=mm}
A3 = {a|a>mm}
case
    |A1|>=k: return select(A1, 1, |A1|, k)
    |A1|+|A2|>=k: return mm
    |A1|+|A2|<k: return select(A3, 1, |A3|, k-|A1|-|A3|)
end case

 

 

下面是C++实现:

 

#include <iostream>
#include <stack>
#include <cmath>

using std::stack;
using std::cout;
using std::endl;

int Split(int * a, int low, int high) {
	int i = low;
	int x = a[low];
	for (int j = low+1; j <= high; j++) {
		if (a[j] <= x) {
			i ++;
			if (i != j) {
				int temp = a[i];
				a[i] = a[j];
				a[j] = temp;
			}
		}
	}
	int temp = a[low];
	a[low] = a[i];
	a[i] = temp;

	return i;
}

void QuickSort(int * a, int low, int high) {
	if (low >= high) {
		return;
	}

	stack<int> range;
	range.push(low);
	range.push(high);
	while(!range.empty()) {
		high = range.top();
		range.pop();
		low = range.top();
		range.pop();

		int w = Split(a, low, high);

		if (low < w-1) {
			range.push(low);
			range.push(w-1);
		}
		if (high > w+1) {
			range.push(w+1);
			range.push(high);
		}
	}
}

//寻找第k小的元素,但会破坏原数组的顺序
int select(int * A, int low, int high, int k) {
	int result = 0;
	int p = high-low+1;
	if (p < 6/*44*/) {
		QuickSort(A, low, high);
		return A[k-1];
	}
	int q = p / 5;
	int * M = new int [q];
	for (int i = 0; i < q; i++) {
		QuickSort(A, i*5, i*5+4);
		M[i] = A[i*5+2];
	}
	int mm = select(M, 0, q-1, int(ceil(q/2.0)));

	int * A1 = new int [p];
	int * A2 = new int [p];
	int * A3 = new int [p];
	int count1 = 0, count2 = 0, count3 = 0;
	for (int i = low; i <= high; i++) {
		if (A[i] < mm) {
			A1[count1++] = A[i];
		} else if (A[i] == mm) {
			A2[count2++] = A[i];
		} else {
			A3[count3++] = A[i];
		}
	}
	if (count1 >= k) {
		result = select(A1, 0, count1-1, k);
	} else if (count1+count2 >= k) {
		result = mm;
	} else if (count1+count2 < k) {
		result = select(A3, 0, count3-1, k-count1-count2);
	}

	delete [] M;
	delete [] A1;
	delete [] A2;
	delete [] A3;
	return result;
}

int main(void) {
	int a[] = {8, 33, 17, 51, 57, 49, 35, 11, 25, 37, 14, 3, 2, 13, 52, 12, 6, 29, 32, 54, 5, 16, 22, 23, 7};
	int result = select(a, 0, 24, 13);

	cout << "序列:\n";
	for (int i = 0; i < 25; i++) {
		cout << a[i] << "  ";
	}
	cout << endl;
	cout << "的第k小元素为:" << result << endl;

	getchar();
	return 0;
}
 

下面是Java版本:

 

package select;

import java.util.ArrayList;
import java.util.Arrays;

import sort.QuickSort;

public class SelectArray {
	private ArrayList<Integer> array = new ArrayList<Integer>();
	
	public SelectArray(int [] array) {
		this.array.clear();
		for (int i = 0; i < array.length; i++) {
			this.array.add(array[i]);
		}
	}
	
	private int select(int [] A, int low, int high, int k) {
		//QuickSort qs = null;
		int result = 0;  
	    int p = high-low+1;  
	    if (p < 6/*44*/) {  
	        A = new QuickSort(A).getSortedIntArray();  
	        return A[k-1];  
	    }  
	    int q = p / 5;  
	    int [] M = new int [q];  
	    for (int i = 0; i < q; i++) {
	    	int [] t = Arrays.copyOfRange(A, i*5, i*5+4);
	    	t = new QuickSort(t).getSortedIntArray(); 
	        M[i] = t[2];  
	    }  
	    int mm = select(M, 0, q-1, (int)Math.floor(q/2.0));  
	  
	    int [] A1 = new int [p];  
	    int [] A2 = new int [p];  
	    int [] A3 = new int [p];  
	    int count1 = 0, count2 = 0, count3 = 0;  
	    for (int i = low; i <= high; i++) {  
	        if (A[i] < mm) {  
	            A1[count1++] = A[i];  
	        } else if (A[i] == mm) {  
	            A2[count2++] = A[i];  
	        } else {  
	            A3[count3++] = A[i];  
	        }  
	    }  
	    if (count1 >= k) {  
	        result = select(A1, 0, count1-1, k);  
	    } else if (count1+count2 >= k) {  
	    	result = mm;  
	    } else if (count1+count2 < k) {  
	    	result = select(A3, 0, count3-1, k-count1-count2);  
	    }
		return result;  
	}
	
	public int getSelectedElement(int k) {
		int [] A = new int [this.array.size()];
		for (int i = 0; i < A.length; i++) {
			A[i] = this.array.get(i); 
		}
		return select(A, 0, A.length-1, k);
	}
	/**
	 * @param args
	 */
	public static void main(String[] args) {
		// TODO Auto-generated method stub
		int a[] = {8, 33, 17, 51, 57, 49, 35, 11, 25, 37, 14, 3, 2, 13, 52, 12, 6, 29, 32, 54, 5, 16, 22, 23, 7};
		SelectArray sa = new SelectArray(a);
		
		System.out.println("序列:");
		for (int i = 0; i < 25; i++) {  
			System.out.print(a[i] + "  ");  
	    } 
		System.out.println();
		System.out.println("的第k小元素为:" + sa.getSelectedElement(13));
	}
}
 

Python版本如下:

 

#! /usr/bin/env python  
# -*- coding:utf-8 -*-

from math import ceil

class SelectList:
    def __init__(self, l):
        self.array = list()
        for i in l:
            self.array.append(i)

    def select(self, a, low, high, k):
        result = 0
        p = high-low + 1
        
        if p < 6:
            a.sort()
            return a[k-1]
        q = p/5
        M = [0] * q
        for i in range(0, q):
            t = a[i*5:i*5+5]
            t.sort()
            M[i] = t[2]
        mm = self.select(M, 0, q-1, int(ceil(q/2.0)))
        
        a1 = [] 
        a2 = []
        a3 = []
        count1 = 0
        count2 = 0
        count3 = 0
        for i in a:
            if i < mm:
                a1.append(i)
                count1 += 1
            elif i == mm:
                a2.append(i)
                count2 += 1
            else:
                a3.append(i)
                count3 += 1
        
        if count1 >= k:
            result = self.select(a1, 0, count1-1, k)
        elif count1+count2 >= k:
            result = mm
        elif count1+count2 < k:
            result = self.select(a3, 0, count3-1, k-count1-count2)
        return result

    def getSelectedElement(self, k):
        return self.select(self.array, 0, len(self.array)-1, k)

if __name__ == '__main__':
    a = [8, 33, 17, 51, 57, 49, 35, 11, 25, 37, 14, 3, 2, 13, 52, 12, 6, 29, 32, 54, 5, 16, 22, 23, 7]
    sl = SelectList(a)

    print "序列:"
    for i in a:
        print i,
    print
    print "的第k小元素为:", sl.getSelectedElement(13)
 

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值