二叉堆的应用 —— TopK 问题求解


什么是 TOP-K

TOP-K 问题:即求数据结合中前K个最大的元素或者最小的元素,一般情况下数据量都比较大。

比如:专业前 10 名、世界 500 强、富豪榜、游戏中前 100 的活跃玩家等。

思路一

对于 Top-K 问题,能想到的最简单直接的方式就是排序:先将数组排为升序,然后输出前 k 个数。

代码实现

//调整算法里面的交换
void Swap(HPDataType* pa, HPDataType* pb) {
	HPDataType tmp = *pa;
	*pa = *pb;
	*pb = tmp;
}

//向下调整算法 --> 构建小堆
void AdjustDownSmall(HPDataType* a, size_t size, size_t root) {
	size_t parent = root;
	size_t child = 2 * parent + 1; //默认左孩子最小

	while (child < size)
	{
		//1.找出左右孩子中小的那个
		//如果右孩子存在,且右孩子小于size(元素个数),那么就把默认小的左孩子修改为右孩子
		if (child+1 < size && a[child+1] < a[child]) {
			++child;
		}

		//2.把小的孩子去和父亲比较,如果比父亲小,就交换
		if (a[child] < a[parent]) {
			Swap(&a[child], &a[parent]);
			parent = child;
			child = 2 * parent + 1;
		}
		else {
			break; //如果孩子大于等于父亲,那么直接跳出循环
		}
	}
}

int* getLeastNumbers(int* arr, int arrSize, int k, int* returnSize)
{
	*returnSize = k;
	int i = 0;
	//建小堆
	for (i = (arrSize - 1 - 1) / 2; i >= 0; i--)
	{
		AdjustDownSmall(arr, arrSize, i);
	}
	//排降序
	int end = arrSize - 1;
	while (end > 0)
	{
		Swap(&arr[0], &arr[end]);
		AdjustDownSmall(arr, end, 0);
		end--;
	}
	//将最大的k个数存入数组
	int* retArr = (int*)malloc(sizeof(int)*k);
	for (i = 0; i < k; i++)
	{
		retArr[i] = arr[i];
	}
	return retArr;//返回最大的k个数
}

但是排序的时间复杂度为: O ( N ∗ l o g N ) O(N*logN) O(NlogN)

空间复杂度为:O(1)

那么还能不能进一步优化呢?

思路二

建立 N 个数的大堆,然后 Pop K 次,就可以选出最大的前 K 个。

Pop 就是删除的意思,比如堆顶最大的元素是 10,我们先把它取出来,然后再 Pop 掉,此时堆顶就是 次大 的元素。

代码示例

//调整算法里面的交换
void Swap(HPDataType* pa, HPDataType* pb) {
	HPDataType tmp = *pa;
	*pa = *pb;
	*pb = tmp;
}

//向下调整算法 --> 构建大堆
void AdjustDownBig(HPDataType* a, size_t size, size_t root) {
	size_t parent = root;
	size_t child = 2 * parent + 1; //默认左孩子最大

	while (child < size)
	{
		//1.找出左右孩子中小的那个
		//如果右孩子存在,且右孩子小于size(元素个数),那么就把默认小的左孩子修改为右孩子
		if (child + 1 < size && a[child + 1] > a[child]) {
			++child;
		}

		//2.把小的孩子去和父亲比较,如果比父亲小,就交换
		if (a[child] > a[parent]) {
			Swap(&a[child], &a[parent]);
			parent = child;
			child = 2 * parent + 1;
		}
		else {
			break; //如果孩子大于等于父亲,那么直接跳出循环
		}
	}
}

//TopK算法
int* getLeastNumbers(int* arr, int arrSize, int k, int* returnSize)
{
	*returnSize = k;
	int i = 0;
	//建大堆
	for (i = (arrSize - 1 - 1) / 2; i >= 0; i--)
	{
		AdjustDown(arr, arrSize, i);
	}
	//将最大的k个数存入数组
	int* retArr = (int*)malloc(sizeof(int)*k);
	int end = arrSize - 1;
	for (i = 0; i < k; i++)
	{
		retArr[i] = arr[0];//取堆顶数据
		Swap(&arr[0], &arr[end]);//交换堆顶数据与最后一个数据
		//进行一次向下调整,不把最后一个数据看作待调整的数据,所以待调整数据为end=arrSize-1
		AdjustDown(arr, end, 0);
		end--;//最后一个数据的下标改变
	}
	return retArr;//返回最大的k个数
}

注意:这里面是要建大堆

时间复杂度为: O ( N + l o g N ∗ K ) O(N+logN*K) O(N+logNK)

空间复杂度为:O(1)

看起来貌似还不错呀!但是忽略了一个问题:如果我们的 N 非常非常大呢?并且远大于 K

思路三

比如:100 亿个数里面找出最大的前 10 个最大的数呢?

存储 100 亿个整数究竟需要多大的内存空间?让咱们来大概估算一下:

1G = 1024MB

1024MB = 1024 * 1024KB

1024*1024KB = 1024*1024*1024Byte

于是可以得出 1 GB大概有 230 个字节,也就是说 1 GB大概等于 10 亿个字节。

存储 100 亿个整型需要 400 亿个字节,所以存储 100 亿个整型数据大约需要 40 G左右的内存空间。

所以前面两种算法并不适合用于这种海量数据处理。

因为如果数据量非常大,排序就不太可取了,可能数据都不能一下子全部加载到内存中。

最佳的方式就是用堆来解决,基本思路如下:

1)先将数组的前 K 个数建为 小堆
 
2)将数组剩下的 N-K 个数依次与堆顶元素进行比较,若比堆顶元素大,则将堆顶元素替换为该元素,然后再进行一次向下调整,使其仍为小堆。
 
最后堆里面的 K 个数就是最大的前 K 个。

代码实现

//调整算法里面的交换
void Swap(HPDataType* pa, HPDataType* pb) {
	HPDataType tmp = *pa;
	*pa = *pb;
	*pb = tmp;
}

//向下调整算法 --> 构建小堆
void AdjustDownSmall(HPDataType* a, size_t size, size_t root) {
	size_t parent = root;
	size_t child = 2 * parent + 1; //默认左孩子最小

	while (child < size)
	{
		//1.找出左右孩子中小的那个
		//如果右孩子存在,且右孩子小于size(元素个数),那么就把默认小的左孩子修改为右孩子
		if (child+1 < size && a[child+1] < a[child]) {
			++child;
		}

		//2.把小的孩子去和父亲比较,如果比父亲小,就交换
		if (a[child] < a[parent]) {
			Swap(&a[child], &a[parent]);
			parent = child;
			child = 2 * parent + 1;
		}
		else {
			break; //如果孩子大于等于父亲,那么直接跳出循环
		}
	}
}

void PrintTopK(int* a, int n, int k) {
	//1.建堆 --> 用a中前k个元素建小堆
	int* kminHeap = (int*)malloc(sizeof(int) * k);
	if (kminHeap == NULL) {
		printf("malloc fail\n");
		exit(-1);
	}
	//把数组的前K个数放进去(放进去的就是数组的前K个数,大小不一)
	for (int i = 0; i < k; ++i) {
		kminHeap[i] = a[i];
	}

	//k-1是最后一个数的下标,(k-1-1)/2就是倒数的第一个非叶子节点,也就是最后一个节点的父亲节点
	for (int j = (k - 1 - 1) / 2; j >= 0; --j) {
		//建小堆
		AdjustDownSmall(kminHeap, k, j);
	}

	//2.将剩余的n-k个元素依次与堆顶元素比较
	for (int i = k; i < n; ++i) {
		if (a[i] > kminHeap[0]) {
			kminHeap[0] = a[i];
			AdjustDownSmall(kminHeap, k, 0);
		}
	}
	free(kminHeap);
}

注意:这里是建小堆,如果建一个大堆,万一堆顶的数据就是 N 个数中最大的那个,那么不论怎么比较,你后面的 N-K 个元素都没法进堆,因此要用小堆来解决这个问题。

写好了以后,我这边那一段程序来测试一下,下面这段代码中,我生成了 10000 个随机数,并且给定了 10 个超过 10000 的数据,然后我们来测试看看能不能找到

代码示例

//向下调整算法 --> 构建小堆
void AdjustDownSmall(HPDataType* a, size_t size, size_t root) {
	size_t parent = root;
	size_t child = 2 * parent + 1; //默认左孩子最小

	while (child < size)
	{
		//1.找出左右孩子中小的那个
		//如果右孩子存在,且右孩子小于size(元素个数),那么就把默认小的左孩子修改为右孩子
		if (child+1 < size && a[child+1] < a[child]) {
			++child;
		}

		//2.把小的孩子去和父亲比较,如果比父亲小,就交换
		if (a[child] < a[parent]) {
			Swap(&a[child], &a[parent]);
			parent = child;
			child = 2 * parent + 1;
		}
		else {
			break; //如果孩子大于等于父亲,那么直接跳出循环
		}
	}
}

//TopK算法
void PrintTopK(int* a, int n, int k) {
	//1.建堆 --> 用a中前k个元素建小堆
	int* kminHeap = (int*)malloc(sizeof(int) * k);
	if (kminHeap == NULL) {
		printf("malloc fail\n");
		exit(-1);
	}

	for (int i = 0; i < k; ++i) {
		kminHeap[i] = a[i];
	}

	//k-1是最后一个数的下标,(k-1-1)/2就是倒数的第一个非叶子节点,也就是最后一个节点的父亲节点
	for (int j = (k - 1 - 1) / 2; j >= 0; --j) {
		//建小堆
		AdjustDownSmall(kminHeap, k, j);
	}

	//2.将剩余的n-k个元素依次与堆顶元素比较
	for (int i = k; i < n; ++i) {
		if (a[i] > kminHeap[0]) {
			kminHeap[0] = a[i];
			AdjustDownSmall(kminHeap, k, 0);
		}
	}
	
	//打印最大的K个数
	for (int j = 0; j < k; ++j)
	{
		printf("%d ", kminHeap[j]);
	}
	printf("\n");
	free(kminHeap);
}

//测试函数
void TestTopk()
{
	int n = 10000;
	int* a = (int*)malloc(sizeof(int) * n);
	srand(time(0));
	for (int i = 0; i < n; ++i)
	{
		a[i] = rand() % 1000000;
	}
	a[5] = 1000000 + 1;
	a[1231] = 1000000 + 2;
	a[531] = 1000000 + 3;
	a[5121] = 1000000 + 4;
	a[115] = 1000000 + 5;
	a[2335] = 1000000 + 6;
	a[9999] = 1000000 + 7;
	a[76] = 1000000 + 8;
	a[423] = 1000000 + 9;
	a[3144] = 1000000 + 10;
	PrintTopK(a, n, 10);
}

//主函数
int main()
{
	TestTopk();
	return 0;
}

调试运行

在这里插入图片描述

可以看到成功的找到了 10000 个数当中最大的前 10 个,只不过它们没有进行排序。

这种算法的时间复杂度为: O ( K + l o g K ∗ ( N − K ) ) O(K+logK*(N-K)) O(K+logK(NK)),当 N 非常大时,并且 K 很小,那么基本就是 O ( N ) O(N) O(N)

空间复杂度: O ( K ) O(K) O(K)

这样内存的消耗就取决于 K 的值了,就算要在 100 亿个数里面找到最大的 100 个数据,也只需要 400 字节的内存空间了!

  • 1
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Albert Edison

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值