概念
topk问题,即给定一个数组,从该数组中取第k小(大)的元素,或获取前k小(大)的元素。本文给出的例子都是基于获取第K小的元素。
例子:int[] arr = {5, 2, 9, 3, 8, 4, 1, 6, 7}; int k=4; 通过某个算法后最后返回的是4。
1、O(n2)排序算法
很容易想到的,最简单的冒泡排序、选择排序,都可以将前k小的元素获取到。考虑优化一下,可以局部排序,使得最后的时间复杂度为O(n*k)
private static int getSort(int[] arr, int k) {
for (int i = 0; i < k; i++) {
for (int j = i + 1; j < arr.length; j++) {
if (arr[i] > arr[j]) {
int t = arr[i];
arr[i] = arr[j];
arr[j] = t;
}
}
}
return arr[k - 1];
}
2、O(n*log(n))排序算法:快速排序
先写一个快速排序,原始的快速排序模板为:
private static void quickSort(int[] arr, int left, int right) {
if (left < right) {
int pivot = getPivot(arr, left, right);
// 遍历左侧
quickSort(arr, left, pivot - 1);
// 遍历右侧
quickSort(arr, pivot + 1, right);
}
}
private static int getPivot(int[] arr, int left, int right) {
// 1、 选取基准元素
int pivotVal = arr[left];
// 2、遍历
while (left < right) {
// 3、右边的往前移动
while (left < right && arr[right] >= pivotVal) {
right--;
}
arr[left] = arr[right];
// 4、左边的往前移动
while (left < right && arr[left] <= pivotVal) {
left++;
}
arr[right] = arr[left];
}
// 5、重新定位基准
arr[left] = pivotVal;
return left;
}
我们针对 quickSort 方法做一个简单的改造,如果基准值==第k个节点(索引为k-1),则返回当前的值;如果基准值 > 第k个节点(索引为k-1),说明第k个节点再基准值左边,应该递归遍历左边;否则递归遍历右边。
改造后的快速排序算法如下:
private static int quickSort(int[] arr, int left, int right, int k) {
if (left < right) {
int pivot = getPivot(arr, left, right);
// 添加逻辑(稍微改一下代码)
if (pivot == k - 1) {
return arr[pivot];
} else if (k - 1 < pivot) {
return quickSort(arr, left, pivot - 1, k);
} else {
return quickSort(arr, pivot + 1, right, k);
}
}
return arr[left];
}
3、小根堆
很容易想到,可以建一个小根堆,把元素入堆,弹出前k个元素即可,但存在问题:创建堆的时间复杂度是O(n),删除的时间复杂度是O(k*log(n)),时间复杂度较高。该问题后续进一步优化。
private static int minHeap(int[] arr, int k) {
// 创建小根堆
PriorityQueue<Integer> minHeap = new PriorityQueue<>();
for (Integer data : arr) {
minHeap.add(data);
}
int res = arr[0];
for (int i = 1; i <= k; i++) {
res = minHeap.poll();
}
return res;
}
4、大根堆
换种思路,要求前k个最小元素,使用大根堆(创建k个元素的大根堆)。
为什么要建大根堆呢?
因为大根堆可以保证k个元素中的最大值在堆顶,然后再将剩下的n-k个元素与堆顶元素比较:
【如果该元素 < 堆顶元素,就将对顶元素删除,该元素入堆】
private static int maxHeap(int[] arr, int k) {
// 创建大根堆
PriorityQueue<Integer> maxHeap = new PriorityQueue<>(new Comparator<Integer>() {
@Override
public int compare(Integer o1, Integer o2) {
return o2 - o1;
}
});
// 让前k个元素入堆,创建大根堆
for (int i = 0; i < k; i++) {
maxHeap.add(arr[i]);
}
// 遍历后面的n-k个元素
for (int i = k; i < arr.length; i++) {
if (arr[i] < maxHeap.peek()) {
maxHeap.poll();
maxHeap.add(arr[i]);
}
}
// 获取堆顶元素
return maxHeap.peek();
}
5、简化一下大根堆代码,作为模版方法
我们发现,4、中的代码:
// 让前k个元素入堆,创建大根堆
for (int i = 0; i < k; i++) {
maxHeap.add(arr[i]);
}
// 遍历后面的n-k个元素
for (int i = k; i < arr.length; i++) {
if (arr[i] < maxHeap.peek()) {
maxHeap.poll();
maxHeap.add(arr[i]);
}
}
可以提取到一个循环中,用来简化一下代码,先让元素入堆,当堆大小 > k的时候,就移除堆顶元素即可,这样就保证了,堆中的元素一直都是k,简化后的代码如下:
// 让前k个元素入堆,创建大根堆
for (int i = 0; i < arr.length; i++) {
maxHeap.add(arr[i]);
if (k < maxHeap.size()) {
maxHeap.poll();
}
}
6、完整代码
附本文完整代码
package common.sort;
import java.util.Comparator;
import java.util.PriorityQueue;
public class TopK {
public static void main(String[] args) {
int[] arr = {5, 2, 9, 3, 8, 4, 1, 6, 7};
// 方法一:使用普通的排序,
// System.out.println(getSort(arr, 5));
// 方法二:使用快速排序
// System.out.println(quickSort(arr, 0, arr.length - 1, 5));
// 方法三:使用时间复杂度高的堆(小根堆)
// System.out.println(minHeap(arr,4));
// 方法四:使用大根堆(创建k个元素的大根堆)
// System.out.println(maxHeap(arr, 4));
// 方法五:使用大根堆,在方法四基础上优化一下写法,作为最终模版方法
System.out.println(maxHeapTemplate(arr, 4));
}
private static int maxHeapTemplate(int[] arr, int k) {
// 创建大根堆
PriorityQueue<Integer> maxHeap = new PriorityQueue<>((o1, o2) -> {
return o2 - o1;
});
// 让前k个元素入堆,创建大根堆
for (int i = 0; i < arr.length; i++) {
maxHeap.add(arr[i]);
if (k < maxHeap.size()) {
maxHeap.poll();
}
}
// 获取堆顶元素
return maxHeap.peek();
}
// 方法四:换种思路,要求前k个最小元素,使用大根堆(创建k个元素的大根堆)
// 为什么要建大根堆呢?因为大根堆可以保证k个元素中的最大值在堆顶,然后再将剩下的n-k个元素与堆顶元素比较:【如果该元素 < 堆顶元素,就将对顶元素删除,该元素入堆】
private static int maxHeap(int[] arr, int k) {
// 创建大根堆
PriorityQueue<Integer> maxHeap = new PriorityQueue<>(new Comparator<Integer>() {
@Override
public int compare(Integer o1, Integer o2) {
return o2 - o1;
}
});
// 让前k个元素入堆,创建大根堆
for (int i = 0; i < k; i++) {
maxHeap.add(arr[i]);
}
// 遍历后面的n-k个元素
for (int i = k; i < arr.length; i++) {
if (arr[i] < maxHeap.peek()) {
maxHeap.poll();
maxHeap.add(arr[i]);
}
}
// 获取堆顶元素
return maxHeap.peek();
}
// 方法3:使用小根堆,思想,先让所有元素入优先级队列,然后弹出k个
// 问题:创建堆的时间复杂度是O(n),删除的时间复杂度是O(k*log(n)),时间复杂度较高
private static int minHeap(int[] arr, int k) {
// 创建小根堆
PriorityQueue<Integer> minHeap = new PriorityQueue<>();
for (Integer data : arr) {
minHeap.add(data);
}
int res = arr[0];
for (int i = 1; i <= k; i++) {
res = minHeap.poll();
}
return res;
}
// 方法2:通过快速排序改造
private static int quickSort(int[] arr, int left, int right, int k) {
if (left < right) {
int pivot = getPivot(arr, left, right);
// 添加逻辑(稍微改一下代码)
if (pivot == k - 1) {
return arr[pivot];
} else if (k - 1 < pivot) {
return quickSort(arr, left, pivot - 1, k);
} else {
return quickSort(arr, pivot + 1, right, k);
}
}
return arr[left];
}
/* private static int[] quickSort(int[] arr, int left, int right, int k) {
if (left < right) {
int pivot = getPivot(arr, left, right);
if (pivot == k - 1) {
return Arrays.copyOfRange(arr, 0, k);
} else if (pivot > k - 1) {
return quickSort(arr, left, pivot - 1, k);
} else {
return quickSort(arr, pivot + 1, right, k);
}
}
return Arrays.copyOfRange(arr, 0, k);
}*/
private static int getPivot(int[] arr, int left, int right) {
// 1、 选取基准元素
int pivotVal = arr[left];
// 2、遍历
while (left < right) {
// 3、右边的往前移动
while (left < right && arr[right] >= pivotVal) {
right--;
}
arr[left] = arr[right];
// 4、左边的往前移动
while (left < right && arr[left] <= pivotVal) {
left++;
}
arr[right] = arr[left];
}
// 5、重新定位基准
arr[left] = pivotVal;
return left;
}
// 方法1:通过排序找(局部排序)
// 时间复杂度:O(n*k)
/*
private static int getSort(int[] arr, int k) {
for (int i = 0; i < k; i++) {
for (int j = i + 1; j < arr.length; j++) {
if (arr[i] > arr[j]) {
int t = arr[i];
arr[i] = arr[j];
arr[j] = t;
}
}
}
return arr[k - 1];
}
*/
}