寻找最大的前K个数
选择问题的特殊情况是找最大者或最小者,这当然很简单了。还是一个特例找中位数。
利用快速排序中的partition操作
经过partition后,pivot左边的序列sa都大于pivot右边的序列sb;
如果|sa|==K或者|sa|==K-1(当|sa|==K-1,则是左边的sa加上partition返回的那个元素组成K个数),则数组的前K个元素就是最大的前K个元素,算法终止;
如果|sa|<K-1,则从sb中寻找前K-|sa|-1大的元素;
如果|sa|>K,则从sa中寻找前K大的元素。
一次partition(arr,begin,end)操作的复杂度为end-begin,也就是O(N),最坏情况下一次partition操作只找到第1大的那个元素,则需要进行K次partition操作,总的复杂度为O(N*K)。平均情况下每次partition都把序列均分两半,需要log(2,K)次partition操作,总的复杂度为O(N*log(2,K))。
|
在选择pivot的时候传统的做法是选第1个元素作为pivot,一种优化的方法是随机选,更好的方法是三元取中法,更更好的方法是取五分化中项的中项,即把序列分为M组,每组5个元素,对每个组进行组内排序得到中项,然后对M个组按中项进行排序,取中间那个组的中项作为pivot。
利用小根堆实现
顺序读取数组中的前K个元素,构建小根堆。小根堆的特点是根元素最小,并且一次调整(deleteMin)操作的时间复杂度为log(2,K)。
接下来从数组中取下一个元素,如果该元素不比堆顶元素大,则丢弃;否则用它替换堆顶元素,然后调整小根堆。
当把数组中的元素全部读出来后,小根堆中保留的就是前K大的元素。
初始建堆操作需要K*log(2,K)--这是最多的操作次数,从数组中读取后N-K个元素和堆顶元素一一比较,最坏的情况是每次都要替换堆顶元素,都要调整小根堆,复杂度为(N-K)*log(2,K)。总的复杂度为O(N*log(2,K))。
|
用partition方法时需要知道总元素个数N,且内存中要能够容下这么多元素,而使用堆就完全没有这些限制了,堆的大小是K,内存中只要能容下这个堆就可以了。
此题多数互联网公司都有提及,这里简单描述一下。
首先,被问到这题应该先询问数据规模与数据分布。如果数据规模比较小,在千数量级,采用O(nlgn)排序取前K个即可。如果数据为整形,且分布范围不大,可以考虑计数排序,在线性时间中求解。
其次,如果不是上面讨论的情况,就是大规模一般情况。数据集可能在10亿个整形数中取最大的1W个。10亿个整形数全部装入内存大概需要4G空间。
以下采用两种方法:
1、快排方法,快排采用分治思想,每次把数组分成两部分,所以这里关键就是找到第K大的数的那次划分,前一部分数组就是我们需要的。
2、堆方法,1方法的不足是需要把所有数据装入内存,如果内存空间不足,系统颠簸,性能必然下降。如果取最大的K个数,可以先用前K个数建立一个最小堆,然后每次读入一个之后的数据与堆顶元素比较,如果比堆顶元素大则替换,并且heapify维护堆性质。
C/C++源码:
代码通过宏定义QUICK来切换快排方法与堆方法
其中堆方法不需要把所有数据读入内存,但这里为了屏蔽从文件读数据的时间影响,采用先把数据都读入再处理方法,并且为了验证结果正确,结果都把前K个数排序。
#include <iostream>
#include <stack>
#include <cassert>
#include <cstring>
using namespace std;
const int num_per_line = 10;
int string_to_num(const char* str){
int len = strlen(str), sum=0;
for(int i=0;i<len;i++)
sum = sum*10 + str[i]-'0';
return sum;
}
int deal_opt(string& in,string& out,int& n, int& k, int argc, char *argv[]){
for(int i=1;i<argc;i++){
if(!strncmp("-i",argv[i],2) && i<argc-1){
in = argv[i+1];
i++;
}else if(!strncmp("-o",argv[i],2) && i<argc-1){
out = argv[i+1];
i++;
}else if(!strncmp("-n",argv[i],2) && i<argc-1){
n = string_to_num(argv[i+1]);
i++;
}else if(!strncmp("-k",argv[i],2) && i<argc-1){
k = string_to_num(argv[i+1]);
i++;
}else
return 1;
}
assert(n>=k);
return 0;
}
void swap(int& a,int& b){
int tmp=a;
a=b;
b=tmp;
}
//create min heap
void heapify(int A[], int i, int n){
#define LC(i) (2*i+1)
#define RC(i) (2*i+2)
while(i<n/2){
int min = A[i], f=0;
if(min>A[LC(i)] && LC(i)<n) min=A[LC(i)],f=1;
if(min>A[RC(i)] && RC(i)<n) f=2;
if(1==f){
swap(A[i],A[LC(i)]);
i=LC(i);
}else if(2==f){
swap(A[i],A[RC(i)]);
i=RC(i);
}else
break;
}
}
void heapSort(int A[],int n){
if(n<=1) return;
for(int i=n/2-1;i>=0;i--)
heapify(A,i,n);
swap(A[0],A[n-1]);
for(int i=n-2;i>=1;i--){
heapify(A,0,i+1);
swap(A[0],A[i]);
}
}
struct node{
node(int a,int b):l(a),r(b){}
int l,r;
};
int main(int argc, char *argv[]){
int n=(1<<30), k=10;
string infile("din.txt"), outfile("dout.txt");
string help("command [-i infile] [-o outfile] [-n arrNum] [-k firstKBigest]");
if(0==deal_opt(infile,outfile,n,k,argc,argv)){
int *arr = new int[n];
FILE* fptr=NULL;
if((fptr=fopen(infile.c_str(),"r"))!=NULL){
int data, i=0;
clock_t s = clock();
while(i<n && (fscanf(fptr,"%d",&data))!=EOF)
arr[i++]=data;
n=i;
assert(n>=k);
fclose(fptr);
double timeUsed = (double)(clock()-s)/CLOCKS_PER_SEC;
cout << "input timeUsed is " << timeUsed << "s" << endl;
#ifndef QUICK
s = clock();
for(int i=k/2-1;i>=0;i--)//create min heap
heapify(arr,i,k);
for(int i=k+1;i<n;i++){
if(arr[i]>arr[0]){
arr[0]=arr[i];
heapify(arr,0,k);
}
}
timeUsed = (double)(clock()-s)/CLOCKS_PER_SEC;
cout << "K min heap timeUsed is " << timeUsed << "s" <<endl;
s=clock();
heapSort(arr,k);
timeUsed = (double)(clock()-s)/CLOCKS_PER_SEC;
cout << "sort timeUsed is " << timeUsed << "s" << endl;
#else
s = clock();
//quick from max to min
stack<node*> st;
st.push(new node(0,n-1));
while(!st.empty()){
int l=st.top()->l;
int r=st.top()->r;
delete st.top();
st.pop();
int i=l,j=r, pivot=arr[i];
while(i<j){
while(arr[i]>=pivot && i<=j) i++;
while(arr[j]<=pivot && i<=j) j--;
if(i<j) swap(arr[i],arr[j]);
}
swap(arr[l],arr[j]);
if(j<k){
st.push(new node(j+1,r));
}else if(j>k){
st.push(new node(l,j-1));
}else
break;
}
timeUsed = (double)(clock()-s)/CLOCKS_PER_SEC;
cout << "quickSort way timeUsed is " << timeUsed << "s" <<endl;
s=clock();
heapSort(arr,k);
timeUsed = (double)(clock()-s)/CLOCKS_PER_SEC;
cout << "sort timeUsed is " << timeUsed << "s" << endl;
#endif
//
if((fptr=fopen(outfile.c_str(),"w"))!=NULL){
for(int i=0;i<k;i++){
fprintf(fptr,"%d/t",arr[i]);
if(i%num_per_line==num_per_line-1)
fprintf(fptr,"/n");
}
fclose(fptr);
}else{
cout << "incorrect open outfile" << endl;
cout << help << endl;
}
}else{
cout << "incorrect open infile" << endl;
cout << help << endl;
}
}else{
cout << "incorrect options" << endl;
cout << help << endl;
}
return 0;
}
运行结果:
input为从文件读数据时间,第二项是真正操作时间,第三项是为了验证结果的排序时间。
1000W数据规模寻找最大的1W个
10亿数据规模寻找最大的1W个