TopK问题
在海量数据处理中,经常会有一类问题,求最小的K个数,或者,求最大的K个数,这类问题统称为TopK问题。
一、对此类问题的一些思考
如果数据量比较小的话,我么可以通过排序,然后截取其中我们想要的K个数,但假如数据量比较大的话,即使是考虑效率和资源的快速排序(时间复杂度 O ( n l g n ) O(nlgn) O(nlgn),空间复杂度 O ( 1 ) O(1) O(1)),也够喝一壶的,所以我们得考虑其他解决办法:
假如我们可以维护一种结构,这个结构里只存储k个元素,当要向这个结构插入数据tmp
时,我们先从这个结构里选出一个最有可能被排除得数据peek
,将其与待插入数据tmp
做比较,如果tmp
更适合存放在这个结构里,就将peek
从该结构里拿出,再将tmp
插入这个结构,当我们遍历完所有的数据之后,我们就得到了我们最想要的k个数据!
二、引出“堆”这种数据结构
有了这个想法之后,我们就需要考虑,那种数据结构可以满足这种需求呢,堆的概念就体现出来了,比方说大根堆,堆顶元素比其他所有元素都要大,假如将堆顶元素取出之后,堆会有一种调整方式,将一个合适的元素放到堆顶的位置。这正是我们想要的结果!
三、Java有没有相关的实现呢?
在集合中,优先级队列是队列的一种,它有着将数据通过指定策略(优先级)排序的一种能力,下面我们看下示例:
public class PriorityQueueDemo {
public static void main(String[] args) {
PriorityQueue<Integer> defaultQueue = new PriorityQueue<>();
defaultQueue.offer(10);
defaultQueue.offer(9);
System.out.println(defaultQueue.poll() + ", " + defaultQueue.poll());
}
}
运行该程序的输出为:
9, 10
默认情况,我们创建的是一个小根堆!
四、使用PriorityQueue解决topK问题
不多说了,上代码:
package com.zhang;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.PriorityQueue;
public class TopK {
public static List<Integer> minK(List<Integer> data, int k) {
if (Objects.isNull(data)) {
return null;
}
if (data.size() <= k) {
data.sort(Comparator.naturalOrder());
return data;
}
// 获取最小的k个数,建个大根堆
PriorityQueue<Integer> queue = new PriorityQueue<>(k, Comparator.reverseOrder());
// 先插入k个数到堆里
for (int i = 0; i < k; ++i) {
queue.offer(data.get(i));
}
for (int i = k; i < data.size(); ++i) {
Integer peek = queue.peek();
Integer tmp = data.get(i);
// 若某个元素比大根堆堆顶元素(整个堆里最大的元素)还小
// 那就弹出堆顶元素,并将这个元素插入堆里
if (tmp < peek) {
queue.poll();
queue.offer(tmp);
}
}
ArrayList<Integer> result = new ArrayList<>(k);
while (!queue.isEmpty()) {
result.add(queue.poll());
}
return result;
}
public static List<Integer> maxK(List<Integer> data, int k) {
if (Objects.isNull(data)) {
return null;
}
if (data.size() <= k) {
data.sort(Comparator.reverseOrder());
return data;
}
// 建一个小根堆,这样堆顶元素最小
PriorityQueue<Integer> queue = new PriorityQueue<>(k, Comparator.naturalOrder());
for (int i = 0; i < k; ++i) {
queue.offer(data.get(i));
}
for (int i = k; i < data.size(); ++i) {
Integer peek = queue.peek();
Integer tmp = data.get(i);
if (peek < tmp) {
queue.poll();
queue.offer(tmp);
}
}
List<Integer> result = new ArrayList<>(k);
while (!queue.isEmpty()) {
result.add(queue.poll());
}
return result;
}
public static void main(String[] args) {
ArrayList<Integer> data = new ArrayList<>();
for (int i = 0; i < 20; ++i) {
data.add(i);
}
System.out.println("source data: " + data);
System.out.println("min10: " + minK(data, 10));
System.out.println("max10: " + maxK(data, 10));
}
}
该程序的输出为:
source data: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
min10: [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
max10: [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
这可能不是一个好的示例,但是他可以解决我们的topK问题!