HeapSelect求解方法
使用这种方法求解,需要借助一种数据结构叫做堆。这里我们使用的是小根堆。
- 首先把数组中的前k个值用来建一个小根堆。
- 之后的其他数拿到之后进行判断,如果遇到比小顶堆的堆顶的值大的,将它放入堆中
- 最终的堆会是数组中最大的K个数组成的结构,小根堆的顶部又是结构中的最小数,因此把堆顶的值弹出即可得到Top-K。
public class TopKHeapSelect {
public int findKthLargest(int[] num, int k) {
if (num.length < k) return 0;
Queue<Integer> minQueue = new PriorityQueue<>(k + 1);
for (int n : num) {
if (minQueue.size() < k || n > minQueue.peek()) minQueue.offer(n);
if (minQueue.size() > k) minQueue.poll();
}
return minQueue.peek();
}
}
QuickSelect求解方法
1.运用快排的思想,将数组进行一次partition过程之后就可以得到两部分,求出第一部分长度,根据长度和k的关系,来选择递归的继续去转这个partition,知道长度和k相同,也就是得到了第Top-k的数。这个算法的时间复杂度在长期期望下是O(n)的,但是在最坏情况下是O(n^2)的一个算法。
import java.util.Random;
public class TopKQuickSelect {
public int findKthLargest(int[] nums, int k) {
if (nums == null || k < 0) return Integer.MAX_VALUE;
int length = nums.length;
return quickSelect(nums, k, 0, length - 1);
}
/**
* 根据快排思想得到的Quickselect算法
* @param nums
* @param k
* @param start
* @param end
* @return
*/
private int quickSelect(int[] nums, int k, int start, int end) {
Random random = new Random();
int index = random.nextInt(end - start + 1) + start;
int pivot = nums[index];
swap(nums, index, end);
int left = start, right = end;
//下面这个while进行的是partition过程,因为这个partition过程简单,所以直接在这里写了
//代码下面有一个partition函数,具体可以看那个函数的注释,理解整个partition过程
while (left < right) {
if (nums[left++] < pivot) {
swap(nums, --left, --right);
}
}
swap(nums, left, end);
int tempK = left - start;
if (tempK == k - 1) {
return nums[left];
} else if (k <= tempK) {
return quickSelect(nums, k, start, left - 1);
} else {
return quickSelect(nums, k - tempK, left, end);
}
}
/**
* 荷兰国旗问题中的partition过程,小于区域放左边,等于区域放中间,大于区域放右边
* end是呢个pivot
* @param nums
* @param start
* @param end
* @return
*/
private int[] partition(int[] nums, int start, int end) {
int less = start - 1;
int more = end;
while (start < end) {
if (nums[start] < nums[end]) {
swap(nums, ++less, start++);
} else if (nums[start] > nums[end]) {
swap(nums, --more, end);
} else {
start++;
}
}
return new int[] {less, more};
}
/**
* swap
* @param arr --array
* @param i --index1
* @param j --index2
*/
private void swap(int[] arr, int i, int j) {
int tmp = arr[i];
arr[i] = arr[j];
arr[j] = tmp;
}
}
BFPRT算法求解
- BFPRT算法天生就是用来做Top-K问题的,这个算法能够在最坏情况下也做到O(n)的时间复杂度解决Top-K问题,wiki上把这个算法称为Median of medians,实在不知道怎么翻译,直译是中位数的中位数,还挺迷的翻译,摔 :)
- 下面这段话引自wiki,是对这么5个大佬,在1973年发布的Time bounds for selection的paper中发明了这个算法。
M.; Floyd, R. W.; Pratt, V. R.; Rivest, R. L.; Tarjan, R. E. (August 1973). “Time bounds for selection” (PDF). Journal of Computer and System Sciences. 7 (4): 448–461. doi:10.1016/S0022-0000(73)80033-9.
BFPRT算法要解决的问题是什么? 其实就是为了选择一个更好的元素,这个元素就是像算法名一样叫做Median of medians ,这个数其实也不是最中间的一个数,而这个分割的比例在30%到70%之间,这样就可以严格的保证在最差的情况下,这个算法依然可以达到严格意义上的O(n),wiki上给的结果是3.33n
package acmcoder;
public class BFPRT {
// O(N)
private static int getMinKthByBFPRT(int[] arr, int K) {
int[] copyArr = copyArray(arr);
return select(copyArr, 0, copyArr.length - 1, K - 1);
}
public static int[] copyArray(int[] arr) {
int[] res = new int[arr.length];
for (int i = 0; i != res.length; i++) {
res[i] = arr[i];
}
return res;
}
private static int select(int[] arr, int begin, int end, int i) {
if (begin == end) {
return arr[begin];
}
//求中位数的中位数
int pivot = medianOfMedians(arr, begin, end);
//用中位数的中位数做一次partition
int[] pivotRange = partition(arr, begin, end, pivot);
//如果命中直接返回
if (i >= pivotRange[0] && i <= pivotRange[1]) {
return arr[i];
} else if (i < pivotRange[0]) {
return select(arr, begin, pivotRange[0] - 1, i);
} else {
return select(arr, pivotRange[1] + 1, end, i);
}
}
private static int medianOfMedians(int[] arr, int begin, int end) {
int num = end - begin + 1;
int offset = num % 5 == 0 ? 0 : 1;
int[] mArr = new int[num / 5 + offset];
for (int i = 0; i < mArr.length; i++) {
int beginI = begin + i * 5;
int endI = beginI + 4;
mArr[i] = getMedian(arr, beginI, Math.min(end, endI));
}
return select(mArr, 0, mArr.length - 1, mArr.length / 2);
}
private static int[] partition(int[] arr, int begin, int end, int pivotValue) {
int small = begin - 1;
int cur = begin;
int big = end + 1;
while (cur != big) {
if (arr[cur] > pivotValue) {
swap(arr, ++small, cur++);
} else if (arr[cur] < pivotValue) {
swap(arr, cur, --big);
} else {
cur++;
}
}
int[] range = new int[2];
range[0] = small + 1;
range[1] = big - 1;
return range;
}
private static int getMedian(int[] arr, int begin, int end) {
insertionSort(arr, begin, end);
int sum = end + begin;
int mid = (sum / 2) + (sum % 2);
return arr[mid];
}
private static void insertionSort(int[] arr, int begin, int end) {
for (int i = begin + 1; i != end + 1; i++) {
for (int j = i; j != begin; j--) {
if (arr[j - 1] > arr[j]) {
swap(arr, j - 1, j);
} else {
break;
}
}
}
}
public static void swap(int[] arr, int index1, int index2) {
int tmp = arr[index1];
arr[index1] = arr[index2];
arr[index2] = tmp;
}
public static void printArray(int[] arr) {
for (int i = 0; i != arr.length; i++) {
System.out.print(arr[i] + " ");
}
System.out.println();
}
public static void main(String[] args) {
int[] nu = {3,90,56,20,20,20,20,46,72};
//3 20 20 20 20 46 56 72 90
System.out.println(getMinKthByBFPRT(nu, 8));
}
}