top k问题
求数组中最小的k个数。
解法
求解top k问题的方法有很多种,这里主要介绍其中经典的两种,一种是基于堆的方法,另一种是利用快速排序中的partition过程的方法。
解法1
基于快速排序partition过程的方法:partition过程为每次选择一个数,查看数组中元素,如果数组中的元素比这个数小,放到数组左边,如果数组中元素比这个数大则放到数组的右边,最后将这个数和小于区域的最后一个位置进行交换,返回这个位置,则这个位置的左边为小于该数的所有元素,右边为大于等于该数的所有元素。因此多次进行partition过程,当返回的的位置等于k-1,则说明该位置坐标有k-1个元素比该元素小,该元素的右边则为所有大于该元素的其他元素,因此到此为止停止partition过程,返回数组中的前k个元素就找到最小的k个数了。
时间复杂度:O(n)
空间复杂度:O(1)(允许修改原始数组)
空间复杂度:O(n)(不允许修改原始数组,在复制后的数组上进行partition过程)
优点:时间复杂度低。
代码如下:
public static int[] getLeastNumbers(int[] arr, int k) {
if (arr.length < k) {
return arr;
}
if (arr == null || arr.length <= 0) {
return new int[0];
}
int left = 0;
int right = arr.length - 1;
int index = partition(arr, left, right);
while (index != k - 1) {
if (index > k - 1) {
index = partition(arr, left, index - 1);
} else {
index = partition(arr, index + 1, right);
}
}
return Arrays.copyOf(arr, k);
}
public static int partition(int[] arr, int left, int right) {
int less = left - 1;
int curIndex = left;
swap(arr, left + (int) (Math.random() * (right - left + 1)), right);
while (curIndex < right) {
if (arr[curIndex] < arr[right]) {
less++;
if (less != curIndex) {
swap(arr, less, curIndex);
}
curIndex++;
} else {
curIndex++;
}
}
less++;
swap(arr, less, right);
return less;
}
public static void swap(int[] arr, int i, int j) {
int temp = arr[i];
arr[i] = arr[j];
arr[j] = temp;
}
解法2
基于堆的方法:首先创建大顶堆,在堆中加入k个数组元素,然后判断其他数组元素与堆顶元素的大小,如果比堆顶元素小则将堆顶元素抛出,将该元素放入堆中;如果比堆顶元素大或相等,则更换数组下一个元素再次与堆顶元素进行比较,最终将堆中的元素返回,得到最小的k个数。在这里需要说明的是,由于Java中的优先级队列实现的是小顶堆,因此需要重新实现Comparator接口中的compare方法。
时间复杂度:O(nlogk)
空间复杂度:O(k)
优点:空间复杂度低,没有修改原始数组结构。
代码如下:
public static int[] getLeastNumbers(int[] arr, int k) {
if (k == 0) {
return new int[0];
} else if (arr.length <= k) {
return arr;
}
PriorityQueue<Integer> heap = new PriorityQueue<>(k, new Comparator<Integer>() {
@Override
public int compare(Integer o1, Integer o2) {
return o2 - o1;
}
});
for (int element : arr) {
if (heap.isEmpty() || heap.size() < k) {
heap.offer(element);
} else {
if (element < heap.peek()) {
heap.poll();
heap.offer(element);
}
}
}
// int[] array = heap.stream().mapToInt(i -> i).toArray();
int[] res = new int[k];
int index = 0;
while (!heap.isEmpty()) {
res[index++] = heap.poll();
}
return res;
}