一步一步优化TopK问题(Java版本)

概念

topk问题,即给定一个数组,从该数组中取第k小(大)的元素,或获取前k小(大)的元素。本文给出的例子都是基于获取第K小的元素。

例子:int[] arr = {5, 2, 9, 3, 8, 4, 1, 6, 7}; int k=4; 通过某个算法后最后返回的是4。

1、O(n2)排序算法

很容易想到的,最简单的冒泡排序、选择排序,都可以将前k小的元素获取到。考虑优化一下,可以局部排序,使得最后的时间复杂度为O(n*k)

    private static int getSort(int[] arr, int k) {
        for (int i = 0; i < k; i++) {
            for (int j = i + 1; j < arr.length; j++) {
                if (arr[i] > arr[j]) {
                    int t = arr[i];
                    arr[i] = arr[j];
                    arr[j] = t;
                }
            }
        }
        return arr[k - 1];
    }

2、O(n*log(n))排序算法:快速排序

先写一个快速排序,原始的快速排序模板为:

      private static void quickSort(int[] arr, int left, int right) {
        if (left < right) {
            int pivot = getPivot(arr, left, right);
            // 遍历左侧
            quickSort(arr, left, pivot - 1);
            // 遍历右侧
            quickSort(arr, pivot + 1, right);
        }
    }

    private static int getPivot(int[] arr, int left, int right) {
        // 1、 选取基准元素
        int pivotVal = arr[left];

        // 2、遍历
        while (left < right) {

            // 3、右边的往前移动
            while (left < right && arr[right] >= pivotVal) {
                right--;
            }
            arr[left] = arr[right];

            // 4、左边的往前移动
            while (left < right && arr[left] <= pivotVal) {
                left++;
            }
            arr[right] = arr[left];
        }

        // 5、重新定位基准
        arr[left] = pivotVal;
        return left;
    }

我们针对 quickSort 方法做一个简单的改造,如果基准值==第k个节点(索引为k-1),则返回当前的值;如果基准值 > 第k个节点(索引为k-1),说明第k个节点再基准值左边,应该递归遍历左边;否则递归遍历右边。

改造后的快速排序算法如下:

    private static int quickSort(int[] arr, int left, int right, int k) {
        if (left < right) {
            int pivot = getPivot(arr, left, right);

            // 添加逻辑(稍微改一下代码)
            if (pivot == k - 1) {
                return arr[pivot];
            } else if (k - 1 < pivot) {
                return quickSort(arr, left, pivot - 1, k);
            } else {
                return quickSort(arr, pivot + 1, right, k);
            }
        }
        return arr[left];
    }

3、小根堆

很容易想到,可以建一个小根堆,把元素入堆,弹出前k个元素即可,但存在问题:创建堆的时间复杂度是O(n),删除的时间复杂度是O(k*log(n)),时间复杂度较高。该问题后续进一步优化。

    private static int minHeap(int[] arr, int k) {
        // 创建小根堆
        PriorityQueue<Integer> minHeap = new PriorityQueue<>();
        for (Integer data : arr) {
            minHeap.add(data);
        }
        int res = arr[0];
        for (int i = 1; i <= k; i++) {
            res = minHeap.poll();
        }
        return res;
    }

4、大根堆

换种思路,要求前k个最小元素,使用大根堆(创建k个元素的大根堆)。

为什么要建大根堆呢?

因为大根堆可以保证k个元素中的最大值在堆顶,然后再将剩下的n-k个元素与堆顶元素比较:

【如果该元素 < 堆顶元素,就将对顶元素删除,该元素入堆】

private static int maxHeap(int[] arr, int k) {
        // 创建大根堆
        PriorityQueue<Integer> maxHeap = new PriorityQueue<>(new Comparator<Integer>() {
            @Override
            public int compare(Integer o1, Integer o2) {
                return o2 - o1;
            }
        });
        // 让前k个元素入堆,创建大根堆
        for (int i = 0; i < k; i++) {
            maxHeap.add(arr[i]);
        }
        // 遍历后面的n-k个元素
        for (int i = k; i < arr.length; i++) {
            if (arr[i] < maxHeap.peek()) {
                maxHeap.poll();
                maxHeap.add(arr[i]);
            }
        }
        // 获取堆顶元素
        return maxHeap.peek();
    }

5、简化一下大根堆代码,作为模版方法

我们发现,4、中的代码:

// 让前k个元素入堆,创建大根堆
        for (int i = 0; i < k; i++) {
            maxHeap.add(arr[i]);
        }
        // 遍历后面的n-k个元素
        for (int i = k; i < arr.length; i++) {
            if (arr[i] < maxHeap.peek()) {
                maxHeap.poll();
                maxHeap.add(arr[i]);
            }
        }

可以提取到一个循环中,用来简化一下代码,先让元素入堆,当堆大小 > k的时候,就移除堆顶元素即可,这样就保证了,堆中的元素一直都是k,简化后的代码如下:

        // 让前k个元素入堆,创建大根堆
        for (int i = 0; i < arr.length; i++) {
            maxHeap.add(arr[i]);
            if (k < maxHeap.size()) {
                maxHeap.poll();
            }
        }

6、完整代码

附本文完整代码

package common.sort;

import java.util.Comparator;
import java.util.PriorityQueue;

public class TopK {
    public static void main(String[] args) {

        int[] arr = {5, 2, 9, 3, 8, 4, 1, 6, 7};
        // 方法一:使用普通的排序,
        // System.out.println(getSort(arr, 5));
        // 方法二:使用快速排序
        // System.out.println(quickSort(arr, 0, arr.length - 1, 5));
        // 方法三:使用时间复杂度高的堆(小根堆)
        // System.out.println(minHeap(arr,4));
        // 方法四:使用大根堆(创建k个元素的大根堆)
        // System.out.println(maxHeap(arr, 4));
        // 方法五:使用大根堆,在方法四基础上优化一下写法,作为最终模版方法
        System.out.println(maxHeapTemplate(arr, 4));
    }

    private static int maxHeapTemplate(int[] arr, int k) {
        // 创建大根堆
        PriorityQueue<Integer> maxHeap = new PriorityQueue<>((o1, o2) -> {
            return o2 - o1;
        });
        // 让前k个元素入堆,创建大根堆
        for (int i = 0; i < arr.length; i++) {
            maxHeap.add(arr[i]);
            if (k < maxHeap.size()) {
                maxHeap.poll();
            }
        }
        // 获取堆顶元素
        return maxHeap.peek();
    }

    // 方法四:换种思路,要求前k个最小元素,使用大根堆(创建k个元素的大根堆)
    // 为什么要建大根堆呢?因为大根堆可以保证k个元素中的最大值在堆顶,然后再将剩下的n-k个元素与堆顶元素比较:【如果该元素 < 堆顶元素,就将对顶元素删除,该元素入堆】
    private static int maxHeap(int[] arr, int k) {
        // 创建大根堆
        PriorityQueue<Integer> maxHeap = new PriorityQueue<>(new Comparator<Integer>() {
            @Override
            public int compare(Integer o1, Integer o2) {
                return o2 - o1;
            }
        });
        // 让前k个元素入堆,创建大根堆
        for (int i = 0; i < k; i++) {
            maxHeap.add(arr[i]);
        }
        // 遍历后面的n-k个元素
        for (int i = k; i < arr.length; i++) {
            if (arr[i] < maxHeap.peek()) {
                maxHeap.poll();
                maxHeap.add(arr[i]);
            }
        }
        // 获取堆顶元素
        return maxHeap.peek();
    }

    // 方法3:使用小根堆,思想,先让所有元素入优先级队列,然后弹出k个
    // 问题:创建堆的时间复杂度是O(n),删除的时间复杂度是O(k*log(n)),时间复杂度较高
    private static int minHeap(int[] arr, int k) {
        // 创建小根堆
        PriorityQueue<Integer> minHeap = new PriorityQueue<>();
        for (Integer data : arr) {
            minHeap.add(data);
        }
        int res = arr[0];
        for (int i = 1; i <= k; i++) {
            res = minHeap.poll();
        }
        return res;
    }


    // 方法2:通过快速排序改造
    private static int quickSort(int[] arr, int left, int right, int k) {
        if (left < right) {
            int pivot = getPivot(arr, left, right);

            // 添加逻辑(稍微改一下代码)
            if (pivot == k - 1) {
                return arr[pivot];
            } else if (k - 1 < pivot) {
                return quickSort(arr, left, pivot - 1, k);
            } else {
                return quickSort(arr, pivot + 1, right, k);
            }
        }
        return arr[left];
    }
/*    private static int[] quickSort(int[] arr, int left, int right, int k) {
        if (left < right) {
            int pivot = getPivot(arr, left, right);
            if (pivot == k - 1) {
                return Arrays.copyOfRange(arr, 0, k);
            } else if (pivot > k - 1) {
                return quickSort(arr, left, pivot - 1, k);
            } else {
                return quickSort(arr, pivot + 1, right, k);
            }
        }
        return Arrays.copyOfRange(arr, 0, k);
    }*/

    private static int getPivot(int[] arr, int left, int right) {
        // 1、 选取基准元素
        int pivotVal = arr[left];

        // 2、遍历
        while (left < right) {

            // 3、右边的往前移动
            while (left < right && arr[right] >= pivotVal) {
                right--;
            }
            arr[left] = arr[right];

            // 4、左边的往前移动
            while (left < right && arr[left] <= pivotVal) {
                left++;
            }
            arr[right] = arr[left];
        }

        // 5、重新定位基准
        arr[left] = pivotVal;
        return left;
    }


    // 方法1:通过排序找(局部排序)
    // 时间复杂度:O(n*k)
    /*
    private static int getSort(int[] arr, int k) {
        for (int i = 0; i < k; i++) {
            for (int j = i + 1; j < arr.length; j++) {
                if (arr[i] > arr[j]) {
                    int t = arr[i];
                    arr[i] = arr[j];
                    arr[j] = t;
                }
            }
        }
        return arr[k - 1];
    }
     */
}

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值