最坏情况为线性时间的第k大元素

在统计和数据分析中,我们经常会遇到求最大值、最小值、中位数、四分位数、Top K等类似需求,其实它们都属于顺序统计量,本文将对顺序统计量的定义和求解算法进行介绍,重点介绍如何在最差时间复杂度也是线性的情况下求解第k大元素。

1. 顺序统计量与选择问题

在一个有 n n n个元素的集合中,第 i i i顺序统计量是该集合中第 i i i小的元素。例如在集合 ( 1 , 3 , 5 , 2 ) (1, 3, 5, 2) (1,3,5,2)中,第2个顺序统计量为2。

从一个有 n n n个元素的集合中,选择出(求解)其第 i i i个顺序统计量的问题被称为选择问题。选择问题的输入输出如下:

输入:一个包含 n n n个不同的数的集合 A A A和一个数 i i i 1 ≤ i ≤ n 1 \leq i \leq n 1in);
输出:元素 x ∈ A x \in A xA,它恰大于A中其它 ( i − 1 ) (i-1) (i1)个元素。

2. 选择问题的求解方法

显然,对输入集合 A A A进行排序之后就可以解决选择问题,使用堆排序或归并排序对输入集合进行排序,然后在排序后的数组中标出第 i i i个元素,即可在 O ( n log ⁡ n ) O(n\log{n}) O(nlogn)时间内完成求解,但还有更快的算法。

2.1 最大值与最小值

我们先考虑选择问题的特殊情况,只求解最大值或最小值,可以发现很容易在 O ( n ) O(n) O(n)时间内完成求解。只需要遍历数组,进行(n-1)次比较即可。以求解最小值为例,伪码如下:

MINIMUM(A)
min = A[1]
n = length[A]
for(i=2; i<=n; i++)
	if(min > A[i])
		min = A[i]
return min

在此基础上,增加一点难度,我们希望同时找最大值和最小值,是否还可以在 O ( n ) O(n) O(n)时间内完成求解呢?答案是肯定的,在遍历过程中不再每次比较一个元素,而是每次比较两个元素,两个元素中较小的元素与当前最小值比较,较大的元素和当前最大值比较,即每对元素需要3次比较即可。伪码如下:

MAX-MINIMUM(A)
if(length[A] is odd)
	min = A[1]
	max = min
else 
	min = MIN(A[1], A[2])
	max = MAX(A[1], A[2])
i++
while(i <= length[A])
	min = MIN(MIN(A[i], A[i+1]), min)
	max = MAX(MAX(A[i], A[i+1]), max)
	i = i + 2 
return min, max

在当前最大值和最小值初始值的设定上,如果n是奇数,将最大值和最小值均设置为第一个元素的值;如果n是偶数,就对前两个元素做一次比较来决定最大值与最小值的初始值。因此,如果n是奇数,那么总共做了 3 ⌊ n / 2 ⌋ 3\lfloor n/2 \rfloor 3n/2次比较;如果n是偶数,那么总共做了 3 n / 2 − 2 3n/2-2 3n/22次比较,时间复杂度为 O ( n ) O(n) O(n)

2.2 期望时间为线性的选择问题

我们回到一般的选择问题,看起来一般的选择问题比求最小值和最大值复杂得多,但神奇的算法仍然可以让我们在平均为线性的时间完成求解。这里再次使用了分治思想,借鉴了快速排序的随机划分方法,如果刚好划分元的左边有(i-1)个元素,则找到第i小的元素;否则,在划分元的左侧或右侧继续进行随机划分。伪码如下:

RANDOMIZED-SELECT(A, p, r, i) // A为数组,p为数组左边界,r为数组右边界,i为待求的顺序统计量序号
if(p == r) // 临界问题处理
	return A[p]
q = RANDOMIZED-PARTITION(A, p, r) //进行划分,返回划分元下标
k = q – p + 1 // k=rank(A[q]) in A[p,…,r], 返回划分元的序号
if(i == k)
	return A[q]
else if(i < k)
	return RANDOMIZED-SELECT(A, p, q - 1, i)
else
	return RANDOMIZED-SELECT(A, q + 1, r, i – k)

可以证明在平均情况下,算法的时间复杂度为 O ( n ) O(n) O(n)。而当运气不好时,每次都只能去除一个元素,算法的时间复杂度就可能达到 O ( n 2 ) O(n^2) O(n2)

2.3 最差时间为线性的选择问题

在上述RANDOMIZED-SELECT算法的基础上,保证每次对数组的划分是个好划分,我们就能进一步在最差情况下也用线性时间解决选择问题。主要步骤如下:

  1. n个元素每5个分为一组,一共 ⌈ n / 5 ⌉ \lceil n/5 \rceil n/5组。最后一组有n mod 5个元素。
  2. 对每组进行排序,取其中位数。若最后一组有偶数个元素,则取较小的中位数。
  3. 递归地使用本算法寻找 ⌈ n / 5 ⌉ \lceil n/5 \rceil n/5个中位数的中位数x
  4. x作为划分元对数组A进行划分,并设x是第k个最小元。
  5. 如果i = k,则返回x;否则如果i < k,则找左区间的第i个最小元;如果i > k,则找右区间的第i - k个最小元。

伪码如下:

SELECT(A, p, r, i)
if(r - p <= 140)
	用简单的排序算法对数组A[p..r进行排序
	return A[p + k - 1]
n = r - p + 1
for(i = 0; i <= floor(n/5); i++) //寻找每组的中位数
	将A[p+5*i]至A[p+5*i+4]的第3小元素与A[p+i]交换位置
x = SELECT(A, p, p+floor(n/5), floor(n/10)) //找中位数的中位数
i = PARTITION(A, p, r, x)
j = i - p + 1
if(k <= j)
	return SELECT(A, p, i, k)
else 
	return SELECT(A, i + 1, r, k - j)

3. 程序代码

以下C语言程序代码实现了最坏情况为线性的select算法,将“求数组a[1..n]中第k大的元素”转化为“求数组a[1..n]中第(n-k+1)小的元素”。递归调用select时,设置数组长度不大于140时,即直接使用插入排序。

3.1 linearSelect_kth.cpp

#include <stdio.h>
#include <stdlib.h>

#define N 1000000     //定义输入数组的最大长度 
#define LEN 5         //定义select中每组元素的个数 

int a[N];

void swap(int *a, int *b) { //交换 a 与 b 的值 
	int tmp = *a;
    *a = *b;
    *b = tmp;
}

int partition(int a[], int low, int high, int pivot) { //将数组a[low..high]划分为 <= pivot和 > pivot的两部分 
    int x;
    int i = low - 1;
    int j;
    for (j = low; j < high; j++) { //在数组中找到值等于privot的元素作为主元,交换到数组最右端 
        if (a[j] == pivot) {
            swap(&a[j], &a[high]);
        }
    }
    x = a[high];
    for (j = low; j < high; j++) { //维护低区a[low..i] <= x, 高区a[i+1..j-1] > x  
        if (a[j] <= x) {           //如果发现a[j] <= x,则将a[j]交换到低区 
            i++;
            swap(&a[i], &a[j]);
        }
    }
    swap(&a[i + 1], &a[high]);    //将主元与最左的大于 x 的元素a[i+1]交换,此时主元到了它应在的位置 
    return i + 1;                 //返回分区完成后主元所在的新下标 
}

void insertSort(int a[], int low, int high) { //对a[low..high]进行插入排序 
    int i, j;
    for (i = low + 1; i <= high; i++) {
        int temp = a[i];
        for (j = i - 1; j >= low && temp < a[j]; j--) {
            a[j + 1] = a[j];
        }
        a[j + 1] = temp;
    }
}

int select(int a[], int begin, int end, int k) { //选出数组a[begin..end]的第k小元素 
    int length = end - begin + 1;   //数组长度,即数组中元素的个数 
    if (length <= 140) {            //长度较小,直接用插入排序 
        insertSort(a, begin, end);
        return a[begin + k - 1];
    }
    int groups = (length + LEN) / LEN;  //组数 
    int i;
    for (i = 0; i < groups; i++) {
        int left = begin + LEN * i;  //第i组的左边界 
        int right = (begin + LEN * i + LEN - 1) > end ? end : (begin + LEN * i + LEN - 1);  //第i组的右边界 
        insertSort(a, left, right);  //组内进行插入排序 
        //将第i组中位数与数组a[]的第i个元素互换位置,方便递归select寻找中位数的中位数
        int mid = (left + right) / 2;
        swap(&a[begin + i], &a[mid]); 
    }
    int pivot = select(a, begin, begin + groups - 1, (groups + 1) / 2);  //找出中位数的中位数
    int p = partition(a, begin, end, pivot);  //用中位数的中位数作为划分的主元
	int leftNum =  p - begin;                 //低区元素的数量 
    if (k == leftNum + 1) {
    	return a[p];
	}
    else if (k <= leftNum) {
    	return select(a, begin, p - 1, k);  //在低区递归调用select来找出第k小的元素 
	}
    else {
    	return select(a, p + 1, end, k - leftNum -1);  //在高区递归调用select来找出第(k-leftNum-1)小的元素 
	}   
}


int main() {
    FILE *fp = fopen("data_1022.txt","r");            //打开文件 
	if (fp == NULL) {
		printf("Can not open the file!\n");
		exit(0);
	}
	int i = 0;
	while (fscanf(fp, "%d\n", &a[i]) != EOF) {  //读取文件中的数据到数组a[]中 
		i++;
	}
	fclose(fp);                                 //关闭文件 
	int k;
	while (1) {
		printf("Please enter an integer k, and you will get the k-th largest element in the array!\n");
		printf("(Enter negative or zero to quit): ");
		scanf("%d", &k);
		if (k <= 0) {
			printf("Bye\n"); 
			break;
		}
    	printf("The %dth largest element in the array is: %d\n", k, select(a, 0, i - 1, i - k + 1));
    	printf("\n==================================================================================\n");
	}
    return 0;
}

3.2 linearSelect_kth_grouplenth_5_vs_7_vs_3.cpp

然后,在3.1节代码的基础上,我尝试改变每组元素的个数,分别设置每组元素个数为5、7、3,比较算法运行时间的差异。

#include <stdio.h>
#include <stdlib.h>
#include <windows.h>

#define N 1000000     //定义输入数组的最大长度 
#define LEN1 5        //尝试改变select中每组元素的个数
#define LEN2 7
#define LEN3 3

int a[N];

void swap(int *a, int *b) { //交换 a 与 b 的值 
	int tmp = *a;
    *a = *b;
    *b = tmp;
}

int partition(int a[], int low, int high, int pivot) { //将数组a[low..high]划分为 <= pivot和 > pivot的两部分 
    int x;
    int i = low - 1;
    int j;
    for (j = low; j < high; j++) { //在数组中找到值等于privot的元素作为主元,交换到数组最右端 
        if (a[j] == pivot) {
            swap(&a[j], &a[high]);
        }
    }
    x = a[high];
    for (j = low; j < high; j++) { //维护低区a[low..i] <= x, 高区a[i+1..j-1] > x  
        if (a[j] <= x) {           //如果发现a[j] <= x,则将a[j]交换到低区 
            i++;
            swap(&a[i], &a[j]);
        }
    }
    swap(&a[i + 1], &a[high]);    //将主元与最左的大于 x 的元素a[i+1]交换,此时主元到了它应在的位置 
    return i + 1;                 //返回分区完成后主元所在的新下标 
}

void insertSort(int a[], int low, int high) { //对a[low..high]进行插入排序 
    int i, j;
    for (i = low + 1; i <= high; i++) {
        int temp = a[i];
        for (j = i - 1; j >= low && temp < a[j]; j--) {
            a[j + 1] = a[j];
        }
        a[j + 1] = temp;
    }
}

int select_5(int a[], int begin, int end, int k) { //选出数组a[begin..end]的第k小元素,分组长度为5 
    int length = end - begin + 1;  //数组长度,即数组中元素的个数
    if (length <= 140) {           //长度较小,直接用插入排序 
        insertSort(a, begin, end);
        return a[begin + k - 1];
    }
    int groups = (length + LEN1) / LEN1;  //组数
    int i;
    for (i = 0; i < groups; i++) {
        int left = begin + LEN1 * i;  //第i组的左边界 
        int right = (begin + LEN1 * i + LEN1 - 1) > end ? end : (begin + LEN1 * i + LEN1 - 1);  //第i组的右边界 
        insertSort(a, left, right);  //组内进行插入排序 
        //将第i组中位数与数组a[]的第i个元素互换位置,方便递归select寻找中位数的中位数
        int mid = (left + right) / 2;
        swap(&a[begin + i], &a[mid]); 
    }
    int pivot = select_5(a, begin, begin + groups - 1, (groups + 1) / 2);  //找出中位数的中位数
    int p = partition(a, begin, end, pivot);  //用中位数的中位数作为划分的主元
    int leftNum =  p - begin;                 //低区元素的数量 
    if (k == leftNum + 1) {
    	return a[p];
	}
    else if (k <= leftNum) {
    	return select_5(a, begin, p - 1, k);  //在低区递归调用select来找出第k小的元素 
	}
    else {
    	return select_5(a, p + 1, end, k - leftNum -1);  //在高区递归调用select来找出第(k-leftNum-1)小的元素 
	}   
}

int select_7(int a[], int begin, int end, int k) {
    int length = end - begin + 1;
    if (length <= 140) {
        insertSort(a, begin, end);
        return a[begin + k - 1];
    }
    int groups = (length + LEN2) / LEN2;
    int i;
    for (i = 0; i < groups; i++) {
        int left = begin + LEN2 * i;  //第i组的左边界 
        int right = (begin + LEN2 * i + LEN2 - 1) > end ? end : (begin + LEN2 * i + LEN2 - 1);  //第i组的右边界 
        insertSort(a, left, right);  //组内进行插入排序 
        //将第i组中位数与数组a[]的第i个元素互换位置,方便递归select寻找中位数的中位数
        int mid = (left + right) / 2;
        swap(&a[begin + i], &a[mid]); 
    }
    int pivot = select_7(a, begin, begin + groups - 1, (groups + 1) / 2);  //找出中位数的中位数
    int p = partition(a, begin, end, pivot);  //用中位数的中位数作为划分的主元
    int leftNum =  p - begin;                 //低区元素的数量 
    if (k == leftNum + 1) {
    	return a[p];
	}
    else if (k <= leftNum) {
    	return select_7(a, begin, p - 1, k);  //在低区递归调用select来找出第k小的元素 
	}
    else {
    	return select_7(a, p + 1, end, k - leftNum -1);  //在高区递归调用select来找出第(k-leftNum-1)小的元素 
	}   
}

int select_3(int a[], int begin, int end, int k) {
    int length = end - begin + 1;
    if (length <= 140) {
        insertSort(a, begin, end);
        return a[begin + k - 1];
    }
    int groups = (length + LEN3) / LEN3;
    int i;
    for (i = 0; i < groups; i++) {
        int left = begin + LEN3 * i;  //第i组的左边界 
        int right = (begin + LEN3 * i + LEN3 - 1) > end ? end : (begin + LEN3 * i + LEN3 - 1);  //第i组的右边界 
        insertSort(a, left, right);  //组内进行插入排序 
        //将第i组中位数与数组a[]的第i个元素互换位置,方便递归select寻找中位数的中位数
        int mid = (left + right) / 2;
        swap(&a[begin + i], &a[mid]); 
    }
    int pivot = select_3(a, begin, begin + groups - 1, (groups + 1) / 2);  //找出中位数的中位数
    int p = partition(a, begin, end, pivot);  //用中位数的中位数作为划分的主元
    int leftNum =  p - begin;                 //低区元素的数量 
    if (k == leftNum + 1) {
    	return a[p];
	}
    else if (k <= leftNum) {
    	return select_3(a, begin, p - 1, k);  //在低区递归调用select来找出第k小的元素 
	}
    else {
    	return select_3(a, p + 1, end, k - leftNum -1);  //在高区递归调用select来找出第(k-leftNum-1)小的元素 
	}   
}


int main() {
    FILE *fp = fopen("data_1022.txt","r");            //打开文件 
	if (fp == NULL) {
		printf("Can not open the file!\n");
		exit(0);
	}
	int i = 0;
	while (fscanf(fp, "%d\n", &a[i]) != EOF) {  //读取文件中的数据到数组a[]中 
		i++;
	}
	fclose(fp);                                 //关闭文件 
	printf("Please enter an integer k, and you will get the k-th largest element in the array:\n");
	int k;
	scanf("%d", &k);
	printf("********************* Group length is 5, array size is 945800**********************\n");
	LARGE_INTEGER nFreq;
	LARGE_INTEGER nBeginTime;
	LARGE_INTEGER nEndTime;
	QueryPerformanceFrequency(&nFreq);
	QueryPerformanceCounter(&nBeginTime); 
    printf("The %dth largest element in the array is: %d\n", k, select_5(a, 0, i - 1, i - k + 1));
    QueryPerformanceCounter(&nEndTime);  //计时结束 
	double time = (double)(nEndTime.QuadPart - nBeginTime.QuadPart) / nFreq.QuadPart * 1000;
	printf("Running time: %lfms\n\n", time);
	
	printf("********************* Group length is 7, array size is 945800 **********************\n");
	QueryPerformanceFrequency(&nFreq);
	QueryPerformanceCounter(&nBeginTime); 
    printf("The %dth largest element in the array is: %d\n", k, select_7(a, 0, i - 1, i - k + 1));
    QueryPerformanceCounter(&nEndTime);  //计时结束 
	time = (double)(nEndTime.QuadPart - nBeginTime.QuadPart) / nFreq.QuadPart * 1000;
	printf("Running time: %lfms\n\n", time);
	
	printf("********************* Group length is 3, array size is 945800 **********************\n");
	QueryPerformanceFrequency(&nFreq);
	QueryPerformanceCounter(&nBeginTime); 
    printf("The %dth largest element in the array is: %d\n", k, select_3(a, 0, i - 1, i - k + 1));
    QueryPerformanceCounter(&nEndTime);  //计时结束 
	time = (double)(nEndTime.QuadPart - nBeginTime.QuadPart) / nFreq.QuadPart * 1000;
	printf("Running time: %lfms\n\n", time);
	
	printf("====================== Group length is 5, array size is 10000 ======================\n");
	QueryPerformanceFrequency(&nFreq);
	QueryPerformanceCounter(&nBeginTime); 
    printf("The %dth largest element in the array is: %d\n", k, select_5(a, 0, 9999, 10000 - k + 1));
    QueryPerformanceCounter(&nEndTime);  //计时结束 
	time = (double)(nEndTime.QuadPart - nBeginTime.QuadPart) / nFreq.QuadPart * 1000;
	printf("Running time: %lfms\n\n", time);
	
	printf("====================== Group length is 7, array size is 10000 ======================\n");
	QueryPerformanceFrequency(&nFreq);
	QueryPerformanceCounter(&nBeginTime); 
    printf("The %dth largest element in the array is: %d\n", k, select_7(a, 0, 9999, 10000 - k + 1));
    QueryPerformanceCounter(&nEndTime);  //计时结束 
	time = (double)(nEndTime.QuadPart - nBeginTime.QuadPart) / nFreq.QuadPart * 1000;
	printf("Running time: %lfms\n\n", time);
	
	printf("====================== Group length is 3, array size is 10000 ======================\n");
	QueryPerformanceFrequency(&nFreq);
	QueryPerformanceCounter(&nBeginTime); 
    printf("The %dth largest element in the array is: %d\n", k, select_3(a, 0, 9999, 10000 - k + 1));
    QueryPerformanceCounter(&nEndTime);  //计时结束 
	time = (double)(nEndTime.QuadPart - nBeginTime.QuadPart) / nFreq.QuadPart * 1000;
	printf("Running time: %lfms\n\n", time);
	
	printf("######################## Group length is 5, array size is 1000 ########################\n");
	QueryPerformanceFrequency(&nFreq);
	QueryPerformanceCounter(&nBeginTime); 
    printf("The %dth largest element in the array is: %d\n", k, select_5(a, 0, 999, 1000 - k + 1));
    QueryPerformanceCounter(&nEndTime);  //计时结束 
	time = (double)(nEndTime.QuadPart - nBeginTime.QuadPart) / nFreq.QuadPart * 1000;
	printf("Running time: %lfms\n\n", time);
	
	printf("######################## Group length is 7, array size is 1000 ########################\n");
	QueryPerformanceFrequency(&nFreq);
	QueryPerformanceCounter(&nBeginTime); 
    printf("The %dth largest element in the array is: %d\n", k, select_7(a, 0, 999, 1000 - k + 1));
    QueryPerformanceCounter(&nEndTime);  //计时结束 
	time = (double)(nEndTime.QuadPart - nBeginTime.QuadPart) / nFreq.QuadPart * 1000;
	printf("Running time: %lfms\n\n", time);
	
	printf("######################## Group length is 3, array size is 1000 ########################\n");
	QueryPerformanceFrequency(&nFreq);
	QueryPerformanceCounter(&nBeginTime); 
    printf("The %dth largest element in the array is: %d\n", k, select_3(a, 0, 999, 1000 - k + 1));
    QueryPerformanceCounter(&nEndTime);  //计时结束 
	time = (double)(nEndTime.QuadPart - nBeginTime.QuadPart) / nFreq.QuadPart * 1000;
	printf("Running time: %lfms\n\n", time);
    return 0;
}

4. 运行结果

程序使用的测试数据可在本文所附资源处或点击此链接下载。在linearSelect_kth.cpp中,设置分组长度为5。运行该程序,程序循环提示输入整数k,按下回车后会输出第k大的元素,直至输入一个负数或0,程序终止。

在linearSelect_kth_grouplenth_5_vs_7_vs_3.cpp中,大致比较了算法在分组长度为5、7、3以及在不同问题规模的情况下的运行时间。运行该程序,程序提示输入一个整数k,按下回车后,程序依次输出算法分组长度为5、7、3分别在数组长度为945800、10000、1000时的运行时间。

4.1 linearSelect_kth.cpp

  • 第1大: 9999990
  • 第5大: 9999940
  • 第7大: 9999915
  • 第90大: 9998974
  • 第100大:9998835

在这里插入图片描述

4.2 linearSelect_kth_grouplenth_5_vs_7_vs_3.cpp

在这里插入图片描述
在这里插入图片描述
多次运行发现,在数组长度为945800时,基本上运行时间都是组长为7 < 组长为5 < 组长为3。在问题规模减小时,三者的运行时间大小关系略有波动,猜测可能是由于组长为3的select算法是非线性的以及程序运行计时存在误差等。

  • 9
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

fufufunny

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值