Q:查找n个无序数(不重复)的最大K个数(考虑N很大时的时间复杂度)?
分析:如果对N个数进行先排序,再输出最大的K个数,则时间复杂度较大为O(nlgn)。所以可以查找数组的第N-K大的数记为pivot,对数组以pivot进行划分,则右边的数均比pivot大。
如果先找到n-k位置上的数,再划分,则复杂度可以降为O(n)。
由于原数组无序,故可以考虑随机选择查找(平均复杂度为O(n));或者BFPRT算法(即中位数的中位数算法),最坏情况下时间复杂度为O(n).
代码如下:
package BinaryTreeDepth;
public class Solution {
public static void main(String args[]){
int data[]={112,3,456,76,98,5,21,23,32,65,67,87,97,819};
insertSort(data,0,data.length-1);
printArray(data);
int k=8;
int []maxK=maxKSearch(data,0,data.length-1,k);
printArray(maxK);
}
//根据给定的K,返回数组中最大的K个数
public static int [] maxKSearch(int[] array,int low,int high,int k){
if(low==high){
return array;
}
int k_id=high-low+1-k;
int maxKId=BFPRTSearch(array,low,high,k_id); //调用BFPRT算法查找第(n-k)大的数的id
int []data=getArray(array,maxKId); //根据给定id,输出其后面的子数组
insertSort(data,0,data.length-1);
return data;
}
//btpst算法找k_id的数
public static int BFPRTSearch(int []a,int l,int h,int k_id){
if(l==h&&l==k_id){
return l;
}
int mid=midSearch(a,l,h);//求中位数的原id
int mid_new=partition(a,l,h,mid);//求中位数的最终id
if(k_id==mid_new){//比较k_id与min_id的大小,确定在哪边递归
return k_id;
}else if(k_id<mid_new){
h=mid_new-1;
}else{
l=mid_new+1;
}
return BFPRTSearch(a,l,h,k_id);
}
//返回中位数的中位数的id
public static int midSearch(int[]a,int l,int h){
if(l==h){
return l;
}
int i=l;
for(;i<h-5;i+=5){ //每个数组的元素个数为5时
insertSort(a,i,i+4);
swap(a,l+(i-l)/5,i+2);
}
if(i<h){
insertSort(a,i,h);
swap(a,l+(i-l)/5,(i+h)/2);
}
if((i-1)/5>0){
return midSearch(a,l,l+(i-l)/5);
}else
return l;
}
public static void insertSort(int []a,int l,int h){
for(int i=l;i<h;i++){
int j=i+1;
int temp=a[j];
while(j>l&&temp<a[j-1]){
a[j]=a[j-1];
j--;
}
a[j]=temp;
}
}
public static int partition(int []a,int l,int h,int mid){ //根据中位数的id对数组进行划分
int pivot=a[mid];
a[mid]=a[l];
while(l<h){
while(l<h&&pivot<a[h]){
h--;
}
a[l]=a[h];
while(l<h&&pivot>a[l]){
l++;
}
a[h]=a[l];
}
a[l]=pivot;
return l;
}
public static int [] getArray(int []a,int id){//输出给定id之后的元素
int len=a.length-id;
int arr[]=new int [len];
for(int i=0;i<len;i++){
arr[i]=a[id];
id++;
}
return arr;
}
public static void swap(int []a,int i,int j){
int temp=a[i];
a[i]=a[j];
a[j]=temp;
}
public static void printArray(int []array){
for(int i=0;i<array.length;i++){
System.out.print(array[i]+"\t");
}
System.out.print("\n");
}
}