TopK问题

import java.util.Arrays;
import java.util.PriorityQueue;
import java.util.Queue;

/**
 * 看起来分治法的快速选择算法的时间、空间复杂度都优于使用堆的方法,但是要注意到快速选择算法的几点局限性:
 * 第一,算法需要修改原数组,如果原数组不能修改的话,还需要拷贝一份数组,空间复杂度就上去了。
 * 第二,算法需要保存所有的数据。如果把数据看成输入流的话,使用堆的方法是来一个处理一个,不需要保存数据,只需要保存 k 个元素的最大堆。
 * 而快速选择的方法需要先保存下来所有的数据,再运行算法。当数据量非常大的时候,甚至内存都放不下的时候,就麻烦了。所以当数据量大的时候还是用基于堆的方法比较好。
 *
 */
public class TopK {

    public int[] getLeastNumbers(int[] arr, int k) {
        if(arr == null || arr.length == 0){
            return new int[0];
        }
        if(arr.length <= k){
            return arr;
        }
        quickSelect(arr,0,arr.length -1,k);
        int[] result = new int[k];
        for(int i=0;i<k;i++){
            result[i] = arr[i];
        }
        return result;
    }

    /**
     * 快速选择,和快排类似思想为分治法
     * 这个 partition 操作是原地进行的,需要 O(n) 的时间,接下来,快速排序会递归地排序左右两侧的数组。而快速选择(quick select)算法的不同之处在于,
     * 接下来只需要递归地选择一侧的数组。快速选择算法想当于一个“不完全”的快速排序,因为我们只需要知道最小的 k 个数是哪些,并不需要知道它们的顺序。
     * 我们的目的是寻找最小的 k 个数。假设经过一次 partition 操作,枢纽元素位于下标 mm,也就是说,左侧的数组有 mm 个元素,是原数组中最小的 mm 个数。那么:
     * 若 k = m,我们就找到了最小的 kk 个数,就是左侧的数组;
     * 若 k<m ,则最小的 k 个数一定都在左侧数组中,我们只需要对左侧数组递归地 parition 即可;
     * 若 k>m,则左侧数组中的 mm 个数都属于最小的 kk 个数,我们还需要在右侧数组中寻找最小的 k-mk−m 个数,对右侧数组递归地 partition 即可。
     *
     * 空间复杂度 O(1),不需要额外空间。
     * 时间复杂度的分析方法和快速排序类似。由于快速选择只需要递归一边的数组,时间复杂度小于快速排序,期望时间复杂度为 O(n),最坏情况下的时间复杂度为 O(n^2)
     *
     * @param arr
     * @param start
     * @param end
     * @param k
     */
    public void quickSelect(int[] arr,int start,int end,int k){
        int pivotIndex = partition(arr,start,end);
        if(pivotIndex == k){
            // 正好找到最小的 k(m) 个数
            return ;
        }else if(k < pivotIndex){
            // 最小的 k 个数一定在前 m 个数中,递归划分
            quickSelect(arr,start,pivotIndex - 1,k);
        }else {
            // 在右侧数组中寻找最小的 k-m 个数,两次递归调用传入的参数为什么都是 k?特别是第二个调用,
            // 我们在右侧数组中寻找最小的 k-m 个数,但是对于整个数组而言,这是最小的 k 个数。所以说,函数调用传入的参数应该为 k。
            quickSelect(arr,pivotIndex + 1,end,k);
        }
    }

    public int partition(int[] arr,int start,int end){
        int pivot = arr[start];
        int mark = start;
        for(int i=start + 1;i<=end;i++){
            if(arr[i] < pivot){
                mark ++;
                int temp = arr[mark];
                arr[mark] = arr[i];
                arr[i] = temp;
            }
        }
        arr[start] = arr[mark];
        arr[mark] = pivot;
        return mark;
    }

    /**
     * 使用堆,这里使用优先队列
     * 由于使用了一个大小为 k 的堆,空间复杂度为 O(k);
     * 入堆和出堆操作的时间复杂度均为 O(logk),每个元素都需要进行一次入堆操作,故算法的时间复杂度为 O(nlogk)
     *
     * @param
     */
    public int[] getLeastNumbersByHeap(int[] arr, int k) {
        if(k == 0 || arr == null){
            return new int[0];
        }
        if(arr.length <= k){
            return arr;
        }
        //使用一个最大堆,Java 的 PriorityQueue 默认是小顶堆,添加 comparator 参数使其变成最大堆
        Queue<Integer> heap = new PriorityQueue<>(k, (i1, i2) -> Integer.compare(i2, i1));
        for(int item:arr){
            // 当前数字小于堆顶元素才会入堆
            if(heap.isEmpty() || heap.size() < k || item < heap.peek()){
                heap.offer(item);
            }
            if(heap.size() > k){
                // 删除堆顶最大元素
                heap.poll();
            }
        }
        int[] result = new int[k];
        for(int i=0;i<k;i++){
            result[i] = arr[i];
        }
        return result;
    }

    public static void main(String[] args) {
        int[] arr = new int[]{3,2,9,8,4,0,7,5};
        int[] leastNums = new TopK().getLeastNumbers(arr,4);
        System.out.println(Arrays.toString(leastNums));

        arr = new int[]{3,2,9,8,4,0,7,5};
        leastNums = new TopK().getLeastNumbers(arr,4);
        System.out.println(Arrays.toString(leastNums));
    }

}

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值