世界十大经典算法之一,由Blum 、 Floyd 、 Pratt 、 Rivest 、 Tarjan提出,这个算法也解了我之前的对某问题的疑惑。
说个题外话,曾经面试的时候被问到的是如何在n个数中找出第二大的数,抽象下就是求n个数中第k大(小)的数。当时想到的算法是取前k个排序,保存在数组中,然后遍历后n-k,并将每次遍历的数与数组中的k个值比较并重新排序。时间复杂度o(kn),如果k小还好说,但是k大了就不好办了。而BFPRT则在线性时间解决了这个问题。
算法思想类似与快排,找出pivot,如果pivot == k,说明找到了;
如果pivot < k,就在[pivot, n] 中找出第k - pivot个;
如果pivot > k, 找出前pivot 中第k大。
步骤:
1.按每5个一组分成n/5取上界。(至于这里为什么要用5个网上没有查到资料,有知道的伙计们麻烦告知下啊)
2.将每一组的中位数进行排序。(任意排序算法)
3.递归调用select查找上一步所有中位数的中位数,设为x,偶数的话去较小的中位数。
4.按x来分割n,设小于等于x的个数为i,大于x的个数为n-i。
5.如果k == i,则返回x,如果k < i, 递归在小于x的数查找第k小的数,如果 k > i, 递归在大于x的数中查找第k - x的数。
时间复杂度:T(n) <= T(n/5) + T(7n / 10) + O(n); T(n/5)是在中位数序列中获取中位数, 而第3步后大于x的数最坏情况是3n / 10, 所以递归的时间是T(7n/10),O(n)则是遍历和生成中位数数组的时间。
设T(n) = cn,c可以不是常数,如果c是n的线性关系,则T(n) = O(n^2); 设a是常数,an是遍历时间。
T(n) <= c(n/5) + c(7n/10) + an = c(9n/10) + an;
cn <= c(9/10)n + an; c < 10a; 所以c是常数,因此可证T(n)至少是O(n),且是线性的。
这里引用维基百科的解释:T(n) ≤ T(n/5) + T(7n/10) + O(n)
The O(n) is for the partitioning work (we visited each element a constant number of times,
in order to form them into O(n) groups and take each median in O(1) time).
From this, one can then show that T(n) ≤ c*n*(1 + (9/10) + (9/10)2 + ...) = O(n).
下面看代码
#include <iostream>
#include <cstdlib>
using namespace std;
#define SIZE 20
#define myrand() rand() % SIZE
void bubble(int a[], int start, int end){
for(int i = 0; i <= end-start; i++){
for(int j = start; j < end - i; j++){
if(a[j] > a[j+1]){
int tmp = a[j];
a[j] = a[j+1];
a[j+1] = tmp;
}
}
}
}
int partition(int a[], int start, int end, int x){
for(;start < end; start++){
if(a[start] > x){
while(start < end && a[end] > x)
end--;
if(start != end){
int tmp = a[end];
a[end] = a[start];
a[start] = tmp;
}else
break; //这里一定要加这条语句,否则外部循环start会在+1
}
}
return start - 1;
}
int select(int a[], int start, int end, int k){
int i, s, t;
if(end - start < 5){
bubble(a,start,end);
return a[start+k-1];
}
for(i = 0; i < (end-start+1)/5; i++){
s = start + 5*i;
t = s + 4;
bubble(a,s,t);
int tmp = a[start+i];
a[start+i] = a[s+2];
a[s+2] = tmp;
}
if((end-start+1) % 5 != 0){
s = start + 5*i;
bubble(a,s,end);
int tmp = a[start+i];
a[start+i] = a[(s+end)/2];
a[(s+end)/2] = tmp;
i++;
}
i--;
int x = select(a,start, start+i, (i+1)/2);
i = partition(a,start, end, x);
int j = i - start + 1;
//这里之所以没有加入j == k的判断是因为在partiton时无法将x排在正确的位置使得左边都小于x而右边都大于x只能保证一边>,另一边<=;
if(j >= k)
return select(a, start, i, k);
else
return select(a, i+1, end, k-j);
}
int main(){
clock_t start, end;
srand((int)time(NULL));
int a[SIZE];
int n = 5;
for(int i = 0; i < SIZE; i++){
a[i] = myrand();
cout << a[i] << "\t";
if((i+1)%5 == 0)
cout << endl;
}
start = clock();
cout << "the no " << n << " is: " << select(a, 0, SIZE-1, n) << endl;
end = clock();
cout << "Time: " << (double)(end - start) << endl;
return 0;
}
上述代码就是简单的随机出SIZE个数据并寻找第k小的数,那么其实也就找到了前k小的值。