k分位数是大小为n的集合(比如数组)里面的k-1个数,它们把有序的集合分为k个分组,任何两个个分组之间的大小之差的绝对值不超过1(有点类似于平衡二叉树),比如集合{3, 5, 9, 4, 2, 1, 6, 8, 9, 10, 12, 7, 6},排序后为{1, 2, 3, 4, 5, 6, 6, 7, 8, 9, 9, 10, 12},它的4(k = 4)分位数为{4, 6, 9}, 分组后的子集合分别为{1, 2, 3, 4}, {5, 6, 6}, {7, 8, 9}, {9, 10, 12}。要求从集合中找出这k-1个数,并且时间复杂度为O(nlgk)。
在没有进行排序之前,通过简单的计算可以知道上述集合的k-1个数分别位于集合的第4位,第7位和第10位,如果对这k-1个数分别使用Order Statistics算法(算法导论第九章),比如,第一次找出第4小的数,第二次找出第7小的数,第三次找出第10小的数,虽然每次的时间复杂度为O(n),但k-1次则为O(nk),不是O(nlgk)。时间上,要想从k级别降到lgk级别,可以在Order Statistics的基础上使用分治(divide and conquer,算法导论第四章)的思想,每次递归调用(见代码)就把原来的集合规模减半,当小到一个分组的规模时就不再递归,代码及说明如下:
函数说明:partitionBilaterally被orderStatisticsIter调用,而后者又被kthQuantiles调用,所以前两个函数是辅助函数。partitionBilaterally是quick sort里面的分区思想,而orderStatistics则是在一个集合里面找第order(假设参数是order)小的数。
其它说明:代码用模板实现。三个函数里面都有参数stride,读者可以把它看作为1,这是我为了另一个程序重用代码而添加的,所以不必理会。orderStatisticsIter函数没有用递归,而是顺序实现。代码用英文注释,但不难看懂。
代码如下:
1 template <typename T> 2 int partitionBilaterally(T *a, int low, int high, int stride){ 3 int pivot = low; 4 low += stride; 5 T temp = a[pivot]; 6 while(low <= high){ 7 while(a[high] >= temp && low <= high) 8 high -= stride; 9 if(!(low > high)){ 10 a[pivot] = a[high]; 11 pivot = high; 12 high -= stride; 13 } 14 while(a[low] <= temp && low <= high) 15 low += stride; 16 if(!(low > high)){ 17 a[pivot] = a[low]; 18 pivot = low; 19 low += stride; 20 } 21 } 22 a[pivot] = temp; 23 return pivot; 24 }
1 //------------------------------------------------------------------------- 2 // Iterative version, don't forget to update low and 3 // high after comparing k with order 4 //------------------------------------------------------------------------- 5 template <typename T> 6 int orderStatisticsItera(T *a, int low, int high, int order, int stride){ 7 int base = low; 8 int p = partitionBilaterally(a, low, high, stride), k; 9 while((k = (p-base)/stride+1) != order){ 10 if(k < order) 11 low = p + stride; 12 else high = p - stride; 13 p = partitionBilaterally(a, low, high, stride); 14 } 15 return p; 16 }
1 template <typename T> 2 void kthQuantiles(T a[], int low, int high, int k){ 3 if(k == 1) 4 return; 5 int size = high-low+1; 6 int split = k/2; 7 // Parentheses outside the question mark statement must be added 8 int lowerSize = (size/k)*split + (size%k<split ? size%k : split); 9 orderStatisticsItera(a, low, high, lowerSize, 1); 10 kthQuantiles(a, low, low+lowerSize-1, split); 11 cout << a[low+lowerSize-1] << endl; 12 kthQuantiles(a, low+lowerSize, high, k-split); 13 }
测试函数:
1 void testKthQuantiles(){ 2 int a[18] = {9, 5, 2, 4, 31, 16, 7, 4, 3 12, 8, 1, 6, 5, 3, 4, 7, 7, 4}; 4 kthQuantiles(a, 0, 17, 4); 5 }
输出结果:
如有纰漏,敬请指正,如有更好的实现,欢迎交流~!