问题描述:假设N(N>0)个元素是有序的,从所有N个元素中取出最大的K(K<=N)个元素。
先从算法角度分析这个问题:
简单粗暴的解法:
直接快速排序,从结果中选取前K个元素。
时间复杂度:
优点:思路简单,运用JDK自带的排序方法,写出来的代码不易出错,容易维护;算法的时间复杂度还算可以接受。
缺点:当N非常大(比如大于100亿),或者K相对于N非常小的时候,这种方法不是特别高效。
基于快排思想的解法:
虽然我们不会采用快速排序的算法来实现TOP-K问题,但我们可以利用快速排序的思想,在数组中随机找一个元素pivot,将数组分成两部分Sa和Sb,其中Sa的元素>=pivot,Sb的元素
若Sa中元素的个数小于K,其个数为len,则在Sb中查找K-len个数字
如此递归下去,不断把问题分解为更小的问题,直到求出结果。
时间复杂度:
优点:效率高。
缺点:需要至少N的额外内存,在N非常大的时候依然不是特别好的选择。
基于最小堆的解法:
为了查找Top K个大的数,我们可以使用最小堆来存储最大的K个元素。最小堆的堆顶元素就是最大K个数中最小的一个。每次考虑下一个数x时,如果x比堆顶元素小,则不需要改变原来的堆。如果想x比堆顶元素大,那么用x替换堆顶元素, 同时,在替换之后,x可能破坏最小堆的结构,需要调整堆来维持堆的性质。
时间复杂度:
优点:效率高,内存占用少。
缺点:比较复杂,容易出错。不过好在JDK自带一个最小堆实现PriorityQueue。随着N的增长,两种算法时间复杂度差距逐渐增大
采用最小堆的方法,最多需要K的额外内存,而不用考虑总体样本的容量,因此输入可以是一个Iterable类型(无需事先知道元素的规模)。
实现的方法应当是一个通用方法,采用泛型来保证这一点。只要元素类型是Comparable就可以应用该方法。同时,对于输入和输出,应该注意类型的上界和下界(extends和super)。
当我们通过最小堆选择出我们需要的最大的K个元素后,我们希望最后输出的结果是从大到小排列的,PriorityQueue的toArray方法并不能保证元素有序,我们只能逐个从PriorityQueue中取出K个元素来保证有序。这K个元素是从小到大排列的,要从大到小排列,我们需要借助一个中间数组倒序输出(中间数组也有一点讲究,由于Java的泛型是类型擦除的,但数组其实是区分元素类型的,所以只能使用Object数组,这一点和ArrayList内部的数组同理)。
由于最终阶段我们反正要遍历K个元素,考虑到我们有时候选择出K个元素后,并不是要直接使用这K个元素,而是要对这K个元素作某种处理后再使用,我们可以在方法中加一个map逻辑,在最后一步取出元素的过程中就完成处理;从而为方法使用者节省一次遍历。这一特性可以借助lambda表达式中Function来实现。
概括起来,实现要点就是:使用PriorityQueue的最小堆作容器;
不关心输入规模,Iterable即可;
采用合理的泛型约束参数的类型,来达到最大的泛用价值;
使用Function来提供结果映射的特性;
非常重要,鲁棒性的关键,别忘了参数检查!
下面给出一段参考代码:
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.PriorityQueue;
import java.util.function.Function;
final class TopKSelector {
private TopKSelector() {}
static , R> List selectThenMap(
Iterable extends T> src, int k, Function super T, ? extends R> function) {
Objects.requireNonNull(src, "null src!");
if (k <= 0) throw new IllegalArgumentException("'k' must be a positive number!");
PriorityQueue heap = new PriorityQueue<>(k);
int elementCount = 0;
for (T t : src) {
Objects.requireNonNull(t, "null element!");
if (elementCount++ < k) {
heap.add(t);
} else {
T top = heap.peek();
assert top != null;
if (top.compareTo(t) < 0) {
heap.poll();
heap.add(t);
}
}
}
return asList(heap, function);
}
static > List select(Iterable extends T> src, int k) {
return selectThenMap(src, k, Function.identity());
}
private static List asList(
PriorityQueue queue, Function super T, ? extends R> function) {
@SuppressWarnings("unchecked")
R[] a = (R[]) new Object[queue.size()];
for (int i = a.length - 1; i >= 0; i--) a[i] = function.apply(queue.poll());
return Arrays.asList(a);
}
}
来一段测试代码:
public static void main(String[] args) {
List src = Arrays.asList(9, 8, 6, 99, 8, 4, 0, -1, 5, 7);
List result = TopKSelector.selectThenMap(src, 4, Integer::toBinaryString);
System.out.println(result);
}
//[1100011, 1001, 1000, 1000]