Top-K问题的几种求解方法

HeapSelect求解方法

使用这种方法求解,需要借助一种数据结构叫做。这里我们使用的是小根堆。

  • 首先把数组中的前k个值用来建一个小根堆。
  • 之后的其他数拿到之后进行判断,如果遇到比小顶堆的堆顶的值大的,将它放入堆中
  • 最终的堆会是数组中最大的K个数组成的结构,小根堆的顶部又是结构中的最小数,因此把堆顶的值弹出即可得到Top-K。
public class TopKHeapSelect {
    public int findKthLargest(int[] num, int k) {
        if (num.length < k) return 0;
        Queue<Integer> minQueue = new PriorityQueue<>(k + 1);
        for (int n : num) {
            if (minQueue.size() < k || n > minQueue.peek()) minQueue.offer(n);
            if (minQueue.size() > k) minQueue.poll();
        }
        return minQueue.peek();
    }
}

QuickSelect求解方法

1.运用快排的思想,将数组进行一次partition过程之后就可以得到两部分,求出第一部分长度,根据长度和k的关系,来选择递归的继续去转这个partition,知道长度和k相同,也就是得到了第Top-k的数。这个算法的时间复杂度在长期期望下是O(n)的,但是在最坏情况下是O(n^2)的一个算法。

import java.util.Random;

public class TopKQuickSelect {
    public int findKthLargest(int[] nums, int k) {
        if (nums == null || k < 0) return Integer.MAX_VALUE;
        int length = nums.length;
        return quickSelect(nums, k, 0, length - 1);
    }

    /**
     * 根据快排思想得到的Quickselect算法
     * @param nums
     * @param k
     * @param start
     * @param end
     * @return
     */
    private int quickSelect(int[] nums, int k, int start, int end) {
        Random random = new Random();
        int index = random.nextInt(end - start + 1) + start;
        int pivot = nums[index];
        swap(nums, index, end);

        int left = start, right = end;
        //下面这个while进行的是partition过程,因为这个partition过程简单,所以直接在这里写了
        //代码下面有一个partition函数,具体可以看那个函数的注释,理解整个partition过程
        while (left < right) {
            if (nums[left++] < pivot) {
                swap(nums, --left, --right);
            }
        }
        swap(nums, left, end);
        int tempK = left - start;
        if (tempK == k - 1) {
            return nums[left];
        } else if (k <= tempK) {
            return quickSelect(nums, k, start, left - 1);
        } else {
            return quickSelect(nums, k - tempK, left, end);
        }
    }

    /**
     * 荷兰国旗问题中的partition过程,小于区域放左边,等于区域放中间,大于区域放右边
     * end是呢个pivot
     * @param nums
     * @param start
     * @param end
     * @return
     */
    private int[] partition(int[] nums, int start, int end) {
        int less = start - 1;
        int more = end;
        while (start < end) {
            if (nums[start] < nums[end]) {
                swap(nums, ++less, start++);
            } else if (nums[start] > nums[end]) {
                swap(nums, --more, end);
            } else {
                start++;
            }
        }
        return new int[] {less, more};
    }

    /**
     * swap
     * @param arr --array
     * @param i --index1
     * @param j --index2
     */
    private void swap(int[] arr, int i, int j) {
        int tmp = arr[i];
        arr[i] = arr[j];
        arr[j] = tmp;
    }
}

BFPRT算法求解

  • BFPRT算法天生就是用来做Top-K问题的,这个算法能够在最坏情况下也做到O(n)的时间复杂度解决Top-K问题,wiki上把这个算法称为Median of medians,实在不知道怎么翻译,直译是中位数的中位数,还挺迷的翻译,摔 :)
  • 下面这段话引自wiki,是对这么5个大佬,在1973年发布的Time bounds for selection的paper中发明了这个算法。
  • M.; Floyd, R. W.; Pratt, V. R.; Rivest, R. L.; Tarjan, R. E. (August 1973). “Time bounds for selection” (PDF). Journal of Computer and System Sciences. 7 (4): 448–461. doi:10.1016/S0022-0000(73)80033-9.

  • BFPRT算法要解决的问题是什么? 其实就是为了选择一个更好的元素,这个元素就是像算法名一样叫做Median of medians ,这个数其实也不是最中间的一个数,而这个分割的比例在30%到70%之间,这样就可以严格的保证在最差的情况下,这个算法依然可以达到严格意义上的O(n),wiki上给的结果是3.33n

package acmcoder;

public class BFPRT {
    // O(N)
    private static int getMinKthByBFPRT(int[] arr, int K) {
        int[] copyArr = copyArray(arr);
        return select(copyArr, 0, copyArr.length - 1, K - 1);
    }

    public static int[] copyArray(int[] arr) {
        int[] res = new int[arr.length];
        for (int i = 0; i != res.length; i++) {
            res[i] = arr[i];
        }
        return res;
    }

    private static int select(int[] arr, int begin, int end, int i) {
        if (begin == end) {
            return arr[begin];
        }
        //求中位数的中位数
        int pivot = medianOfMedians(arr, begin, end);
        //用中位数的中位数做一次partition
        int[] pivotRange = partition(arr, begin, end, pivot);
        //如果命中直接返回
        if (i >= pivotRange[0] && i <= pivotRange[1]) {
            return arr[i];
        } else if (i < pivotRange[0]) {
            return select(arr, begin, pivotRange[0] - 1, i);
        } else {
            return select(arr, pivotRange[1] + 1, end, i);
        }
    }

    private static int medianOfMedians(int[] arr, int begin, int end) {
        int num = end - begin + 1;
        int offset = num % 5 == 0 ? 0 : 1;
        int[] mArr = new int[num / 5 + offset];
        for (int i = 0; i < mArr.length; i++) {
            int beginI = begin + i * 5;
            int endI = beginI + 4;
            mArr[i] = getMedian(arr, beginI, Math.min(end, endI));
        }
        return select(mArr, 0, mArr.length - 1, mArr.length / 2);
    }

    private static int[] partition(int[] arr, int begin, int end, int pivotValue) {
        int small = begin - 1;
        int cur = begin;
        int big = end + 1;
        while (cur != big) {
            if (arr[cur] > pivotValue) {
                swap(arr, ++small, cur++);
            } else if (arr[cur] < pivotValue) {
                swap(arr, cur, --big);
            } else {
                cur++;
            }
        }
        int[] range = new int[2];
        range[0] = small + 1;
        range[1] = big - 1;
        return range;
    }

    private static int getMedian(int[] arr, int begin, int end) {
        insertionSort(arr, begin, end);
        int sum = end + begin;
        int mid = (sum / 2) + (sum % 2);
        return arr[mid];
    }

    private static void insertionSort(int[] arr, int begin, int end) {
        for (int i = begin + 1; i != end + 1; i++) {
            for (int j = i; j != begin; j--) {
                if (arr[j - 1] > arr[j]) {
                    swap(arr, j - 1, j);
                } else {
                    break;
                }
            }
        }
    }

    public static void swap(int[] arr, int index1, int index2) {
        int tmp = arr[index1];
        arr[index1] = arr[index2];
        arr[index2] = tmp;
    }

    public static void printArray(int[] arr) {
        for (int i = 0; i != arr.length; i++) {
            System.out.print(arr[i] + " ");
        }
        System.out.println();
    }

    public static void main(String[] args) {
        int[] nu = {3,90,56,20,20,20,20,46,72};
        //3 20 20 20 20 46 56 72 90
        System.out.println(getMinKthByBFPRT(nu, 8));
    }

}
  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值