第K大/Top K及其简单实现

转载请注明出处:http://blog.csdn.net/u012469987/。


见网上第K大多数只给思路,没给实现,我就来填坑了。

update 2023-07-23 增加leetCode的Java版,区别是,二分法无法过leetCode,因为处理不了重复元素
update 2017-09-23 有同学反馈说面试遇到这个题,博文给了助攻,哈预料之中。

Top K 和第K大基本等价,以下我们以第K大为例且假设第K大一定存在,Top K 可以在第k大基础上稍微改动获得。
本文介绍6种方法,只考虑实现功能,不做异常判断,面试的话快排和最小堆的方法比较不错,测试提交的话可以去Leetcode,或者直接拿最下面的数据生成代码去对拍跑。

快排的思想 近似O(n)

调用降序快排的partition函数,设区间为[low,high],返回index,则index左边都是大于data[index]的。

  1. 若index及index左边数字有k个则data[index]就是第k大,index及其左边元素为Top K元素
  2. 左边数字大于k个则继续在[low,index]里找
  3. 左边数字小于k个则去右边[index+1,high]找 k - 左边数字个数
#include <cstdio>
#include <iostream>
using namespace std;
const int maxn = 1e5 + 5;
//改为 data[high] >= key 和 data[low] <= key 则为第k小
int part(int *data, int low, int high) {
	int key = data[low];
	while (low < high) {
		while (low < high && data[high] <= key) high--;
		data[low] = data[high];
		while (low < high && data[low] >= key) low++;
		data[high] = data[low] ;
	}
	data[low] = key;
	return low;
}
int k_th(int *data, int k, int low, int high) {
	int pos = part(data, low, high);
	int cnt = pos - low + 1;  //[low,pos]元素个数
	if (cnt == k) return data[pos];
	else if (cnt < k) return k_th(data, k - cnt, pos + 1, high);
	else return k_th(data, k, low, pos);
}
int k_th(int *data, int n, int k) {
	if(k < 1 || k > n) return -1;
	return k_th(data, k, 0, n - 1);	//闭区间
	//遍历data[0,k)即可获得top K,但不能保证有序
}

int main() {

	// int data[] = {1, 5, 6, 7, 3, 2, 10, 9, 0, 231, 3214, 61};
	// int n = sizeof(data) / sizeof(int);
	// int k = 2;
	// cout << k_th(data, n, k) << endl;
	
    // freopen("in.txt","r",stdin);
    // freopen("out.txt","w",stdout);
	int n, k, data[maxn];
	std::ios::sync_with_stdio(false);
	while (cin >> n >> k) {
		for (int i = 0; i < n; ++i) {
			cin >> data[i];
		}
		cout << k_th(data, n, k) << endl;
	}
	return 0;
}

小根堆 O(nlogk)

维护一个k个元素的小根堆,保持堆里元素为最大的K个且堆顶为第k大(堆里最小的),扫一遍数据,若堆里个数小于k则插入,否则看新的数和堆顶数大小关系:

  1. 若新来的数小于等于堆顶,即新元素比Top K里最小的还小,则新来的数显然不可能是前k大
  2. 若新来的数大于堆顶,则删掉堆顶,将新数字放到堆里且调整堆来保持堆的属性

由于实现堆代码量较多,我们可以用C++的优先队列、set等代替手工堆偷跑,当然这里也提供了手动实现版。

#include <cstdio>
#include <vector>
#include <queue>
#include <iostream>
using namespace std;
const int maxn = 1e5 + 5;
//维持一个k大小的最小堆,根据新元素和堆顶大小决定要不要加入堆且删堆顶
// O(nlogk)
int biggest_k_th(int *data, int n, int k) {
    priority_queue<int, vector<int>, greater<int> >q;   //小根堆
    while (!q.empty()) q.pop();

    for (int i = 0; i < n; ++i) {
        if (q.size() < k) {
            q.push(data[i]);
        } else if (data[i] > q.top()) {
            q.pop();
            q.push(data[i]);
        }
    }
    //取k次q.top()且pop()k次即为有序的前K大
    return q.top();
}

int smallest_k_th(int *data, int n, int k) {
    priority_queue<int>q;   //大根堆
    while (!q.empty()) q.pop();

    for (int i = 0; i < n; ++i) {
        if (q.size() < k) {
            q.push(data[i]);
        } else if (data[i] < q.top()) {
            q.pop();
            q.push(data[i]);
        }
    }
    return q.top();
}

int main() {
    // freopen("in.txt","r",stdin);
    // freopen("out.txt","w",stdout);
    std::ios::sync_with_stdio(false);
    int n, k, data[maxn];
    while (cin >> n >> k) {
        for (int i = 0; i < n; ++i) {
            cin >> data[i];
        }
        cout << biggest_k_th(data, n, k) << endl;
    }
    return 0;
}

手动实现版

#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
const int maxn = 1e5 + 5;
const int maxK = 1e5 + 5;

int heapCnt = 0;
int heap[maxK];

void adjust(int *heap, int begin, int end) {	//[begin,end)
	int cur = begin;
	int son = 2 * cur + 1;
	while (son < end) {
		if (son + 1 < end && heap[son] > heap[son + 1]) son++;
		if (heap[cur] < heap[son]) return;
		swap(heap[son], heap[cur]);
		cur = son;
		son = 2 * cur + 1;
	}
}

void buildHeap(int *heap, int k) {	//[heap,heap+k) 开区间
	for (int i = k / 2; i >= 0;  --i) {
		adjust(heap, i, k);
	}
}

int k_th(int *data, int n, int k) {
	heapCnt = 0;
	for (int i = 0; i < n; ++i) {
		if (heapCnt < k) {
			heap[heapCnt++] = data[i];
			if (heapCnt == k) {
				buildHeap(heap, k);	//data[0,k)共k个
			}
		} else {
			if (data[i] > heap[0]) {
				heap[0] = data[i];
				adjust(heap, 0, heapCnt);
			}
		}
	}
	return heap[0];
}

int main() {
	// freopen("in.txt", "r", stdin);
	// freopen("out.txt", "w", stdout);
	int n, k, data[maxn];
	std::ios::sync_with_stdio(false);
	while (cin >> n >> k) {
		for (int i = 0; i < n; ++i) {
			cin >> data[i];
		}
		cout << k_th(data, n, k) << endl;
	}
	return 0;
}

计数排序 O(n)

按照计数排序思想给数据的值计数,再从数据的最大值往最小值遍历,则总次数大于等于k的那个数为第k大,见代码一目了然。
优点:速度快且不用库也代码量少,妥妥的O(n)
缺点:只适用于数值不大的情况,当然你用hashmap这类库计数的话就没这问题了。

#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
const int maxn = 1e5 + 5;
const int maxVal = 1e5 + 5;	//O(n) 适用于数据值不大的情况

int k_th(int *data, int n, int k) {
	int mmin = data[0], mmax = data[0];
	int times[maxVal];
	memset(times,0,sizeof(times));

	for (int i = 0; i < n; ++i) {
		mmin = min(mmin, data[i]);
		mmax = max(mmax, data[i]);
		times[data[i]]++;
	}

	int cnt = 0;
	for (int i = mmax; i >= mmin; --i) {
		cnt += times[i];
		if (cnt >= k) {	// >= 是因为第k大的数可能有若干个,找第一个
			return i;
		}
		//反过来遍历则为第k小
		//每次输出times[i]次i,注意下边界就出了有序前k大
	}
	return -1;
}

int main() {
    // freopen("in.txt","r",stdin);
    // freopen("out.txt","w",stdout);
	int n, k, data[maxn];
	std::ios::sync_with_stdio(false);
	while (cin >> n >> k) {
		for (int i = 0; i < n; ++i) {
			cin >> data[i];
		}
		cout<< k_th(data, n, k) <<endl;
	}
	return 0;
}

二分 O(nlogn)

假设第K大的数字是val,那么val肯定在一个数字区间里,我们叫 [l,r] ,我们就二分这个区间和val。
最开始l=所有数的最小值,r=最大值,假设当前值是mid,如果所有数据中大于等于mid的数字至少k个,说明当前数值可能是答案(若mid存在的情况则将区间调为[mid,r],mid不存在的话就改为[mid+1,r]),否则mid偏大,在[l,mid-1]里查找;二分不会的可见这篇文章。
二分本身是需要有序的,但我们二分的是答案值,int数字本身就有排序效果。

#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
const int maxn = 1e5 + 5;
const int maxVal = 1e5 + 5;

bool ok(int *data, int n, int k, int mid) {
	int cnt = 0;
	for (int i = 0; i < n; ++i) {
		if (data[i] >= mid) cnt++;
	}
	return cnt >= k;
}
int k_th(int *data, int n, int k) {
	int mmin = data[0], mmax = data[0];
	bool vis[maxVal];
	memset(vis, false, sizeof(vis));

	for (int i = 0; i < n; ++i) {
		mmin = min(mmin, data[i]);
		mmax = max(mmax, data[i]);
		vis[data[i]] = true;
	}

	int l = mmin, r = mmax;
	while (l < r) {
		int mid = (l + r + 1) / 2;
		if (ok(data, n, k, mid)) {
			if (!vis[mid]) l = mid + 1;
			else l = mid;
		} else {
			r = mid - 1;
		}
	}
	return l;
}

int main() {
	// freopen("in.txt", "r", stdin);
	// freopen("out.txt", "w", stdout);
	int n, k, data[maxn];
	std::ios::sync_with_stdio(false);
	while (cin >> n >> k) {
		for (int i = 0; i < n; ++i) {
			cin >> data[i];
		}
		cout << k_th(data, n, k) << endl;
	}
	return 0;
}

暴力式选择/冒泡排序 O(kn)

特慢做法:排序k个,每次遍历n个元素,O(k*n)

#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
const int maxn = 1e5 + 5;

int k_th(int *data, int n, int k) {
    for (int i = 0; i < k; ++i) {
        for (int j = 0; j < n - i - 1; ++j) {
            if (data[j] > data[j + 1]) {
                swap(data[j], data[j + 1]);
            }
        }
    }
	return data[n-k];
}

int main() {
	// freopen("in.txt", "r", stdin);
	// freopen("out.txt", "w", stdout);
	int n, k, data[maxn];
	std::ios::sync_with_stdio(false);
	while (cin >> n >> k) {
		for (int i = 0; i < n; ++i) {
			cin >> data[i];
		}
		cout << k_th(data, n, k) << endl;
	}
	return 0;
}

真暴力排序O(nlogn)

排完取 data[k] ,这么暴力就不说了。

数据生成代码

生成10组数据,每组一个n(范围:[a_n,b_n]),然后n个数 [a,b]。

#include <cstdio>
#include <cmath>
#include <cstdlib>
using namespace std;
int rand_ab(int a, int b) { //[a,b]
    return a + rand() % (b + 1 - a);
}
void make(){	
    int a_n = 10000, b_n = 100000;
    int a = 1, b = 10000;
    for (int i = 0; i < 10; ++i) {
        int n = rand_ab(a_n, b_n);
        printf("%d ", n);
        int a_k = 1, b_k = n;
        printf("%d\n", rand_ab(a_k,b_k));
        printf("%d", rand_ab(a, b));
        for (int i = 1; i < n; ++i) {
            printf(" %d", rand_ab(a, b));
        }
        printf("\n");
    }
}
int main() {
    // freopen("out.txt","w",stdout);
    make();
    return 0;
}

Java-leetcode版合集

https://leetcode.cn/problems/kth-largest-element-in-an-array/


public class C215 {
    public int part(int[] nums, int low,int high) {
        int key = nums[low];
        while (low < high){
            while (low < high && key >= nums[high]  ) high--;
            nums[low] = nums[high];
            while (low < high && key <= nums[low]  ) low++;
            nums[high] = nums[low];
        }
        nums[low] = key;
        return low;
    }
    public int pSort(int[] nums, int low,int high, int k) {
        int mid = part(nums,low,high);
        int leftCnt = mid - low + 1;  //[low,pos]元素个数
        if( leftCnt  == k ){
            return nums[mid];
        }
        if( leftCnt < k ){
            return pSort(nums,mid+1,high,k-leftCnt);
        }
        return pSort(nums,low,mid,k);
    }
    public int findKthLargest(int[] nums, int k) {
        return pSort(nums,0,nums.length-1,k);
//        return findKthLargestByTimeSortWithBase(nums,k);
//        return findKthLargestByHeapSort(nums,k);
        //二分方法不行,处理不了重复元素,因为计算数量会错,比如[99,99]查第1大,如果处理了重复元素,结果又不对了,所以只能在无重复元素情况下用
//        return findKthLargestByBSearch(nums,k);
    }
    public int bSearchJudge(int[] nums, int ans ,int k) {
        int cnt = 0;
        for (int i = 0; i < nums.length; i++) {
            if(nums[i] >= ans){
                cnt++;
            }
            if(cnt > k){
                return 1;
            }
        }
        return cnt == k ? 0 : -1;
    }
    public int findKthLargestByBSearch(int[] nums, int k) {
        int base = 10001;
        int low = base,high = -base;
        Set<Integer> set = new HashSet<>();
        for (int i = 0; i < nums.length; i++) {
            high = Math.max(high,nums[i]);
            low = Math.min(low,nums[i]);
            set.add(nums[i]);
        }
        while (low <= high){
            int mid = (low+high)/2;
            int judge = bSearchJudge(nums,mid,k);
            if(judge == 0){
                if(set.contains(mid)) {
                    return mid;
                }else{
                    low = mid+1;
                }
            }else if(judge > 0){
                low = mid+1;
            } else{
                high = mid-1;
            }
        }
        return -1;
    }

    public int findKthLargestByHeapSort(int[] nums, int k) {
        PriorityQueue<Integer> pq = new PriorityQueue<>();
        for (int i = 0; i < nums.length; i++) {
            if(pq.size() < k){
                pq.offer(nums[i]);
            }else{
                int top = pq.peek();
                if(top < nums[i]){
                    pq.poll();
                    pq.offer(nums[i]);
                }
            }
            //优化点:利用优先队列自带的判断,大于K个就直接进去当前元素再踢掉  if (heap.size() > k) {
            //                heap.poll();
            //            }
        }
        return pq.poll();
    }
    public int findKthLargestByTimeSortWithBase(int[] nums, int k) {
        int base = 10005;
        int n = 10005;
        int times[] = new int[n*2];
        for (int i = 0; i < nums.length; i++) {
            times[nums[i]+base]++;
        }
        int cnt = 0;
        for (int i = n*2-1; i >= 0; i--) {
            cnt += times[i];
            if(cnt >= k){
                return i-base;
            }
        }
        return -1;
    }
    public int findKthLargestByTimeSort(int[] nums, int k) {
        int n = 10005;
        int times[] = new int[n];
        for (int i = 0; i < nums.length; i++) {
            times[nums[i]]++;
        }
        int cnt = 0;
        for (int i = n-1; i >= 0; i--) {
            cnt += times[i];
            if(cnt >= k){
                return i;
            }
        }
        return -1;
    }

    public static void main(String[] args) {
//        int[] nums = new int[]{3, 2, 1, 5, 6, 4};
//        System.out.println(new C215().findKthLargest(nums, 2));
        int[] nums2 = new int[]{-1111,-222,3333,555,0,-3};
        int[] nums3 = new int[]{99,99};
        System.out.println(new C215().findKthLargest(nums2, 5));
        System.out.println(new C215().findKthLargest(nums3, 1));
    }
}

  • 6
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
PyTorch的topk函数是用于返回输入张量中指定维度上的前k个最大值及其对应的索引。它的函数签名为torch.topk(input, k, dim=None, largest=True, sorted=True, out=None),返回一个元组,包含最大的k个值组成的张量和它们在输入张量中的索引组成的长整型张量。其中,input是输入张量,k是要返回的最大值的个数,dim是指定的维度,largest决定是否返回最大值(默认为True),sorted决定是否返回排序的结果(默认为True),out是输出的张量。 例如,如果我们有一个输入张量input为[5, 9, 3, 2, 7],我们想要找出其中最大的3个值及其索引,我们可以使用torch.topk(input, 3)。这将返回一个包含[9, 7, 5]的张量和一个包含[1, 4, 0]的长整型张量,分别表示最大的3个值和它们在输入张量中的索引。 在具体的代码中,maxk = max(topk)用于获取topk列表中的最大值,而output.topk(maxk, 1, True, True)则是对output进行topk操作,返回最大值和对应的索引。这种用法可以帮助我们在代码中获取最大的k个值及其索引。 总结来说,PyTorch的topk函数可以帮助我们在指定维度上找出输入张量中的最大值及其对应的索引。这在许多机器学习和深度学习任务中非常有用。如果想要了解更多关于topk函数的用法,可以参考PyTorch官方中文文档或者一篇介绍topk函数用法的文章。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *2* [pytorch 中的topk函数](https://blog.csdn.net/u012505617/article/details/103711019)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"] - *3* [PyTorch中topk函数的用法详解](https://download.csdn.net/download/weixin_38628150/12856649)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值