BFPRT算法的主要步骤和代码实现
解决的问题
求一个无序数组中第k小的数。约定:k是从1开始计数的,即最小的那个数是第一小的数。
解决方案
1、快速排序:平均时间复杂度O(nlogn),最坏时间复杂度O(n2)
2、堆:时间复杂度O(nlogk)
3、快速选择:平均时间复杂度O(n),最坏时间复杂度O(n2)
例如:1, 2, 3, 4, 5
,如果要求最5小的数,使用快速选择时间复杂度为O(n2),原因是基准值选得不好,使得每次根据基准值划分的时候,其它的所有数字都被分到了一边。
第一轮以1
作为基准值,小于基准值的放在左边,大于基准值的放在右边,得到的结果是1, 2, 3, 4, 5
,发现目标数在基准值1
的右边;第二轮以2
作为基准值,得到的结果依旧是1, 2, 3, 4, 5
;直到找到目标值。
优化:每次不以[left, right]
范围内left位置的元素作为基准值,而是从left到right范围内随机选择一个作为基准值,通过随机化的方式来优化时间复杂度,在概率统计上,此时的时间复杂度为O(n)。
4、BFPRT算法
时间复杂度为严格的O(N),而不是概率统计上的O(n)。BFPRT算法与快速选择最主要的区别在于选取基准值的方式不同,选取基准值之后后续的步骤都是一样的。所以理解BFPRT算法前提是要理解快速选择的过程。
BFPRT算法的步骤
- 以五个元素作为一组,对原数组进行分组,最后一个分组如果不足5个元素,依旧可以分为一个组;
- 各个组进行组内排序;
- 取出各个组的中位数,组成一个新的数组,暂且叫做【中位数数组】,最后一个组的元素个数若是偶数,去该组的上中位数或者下中位数都可以;
- 递归调用BFPRT算法来求中位数数组的中位数;
- 以上一步求得的中位数作为基准值pivot,对原数组进行partiton过程(快速选择)
- 根据基准值和目标数的位置关系,如果未命中目标数,回到第一步重新执行。
对第四步的解释:
- 为什么不能直接求出来中位数数组的中位数?因为这个中位数数组里面的元素是无序的,不要被名字所迷惑;
- 为什么可以递归调用BFPRT算法来求中位数?首先这个中位数数组的数据规模肯定是比原数组小的,其次求解问题的语义是相同的,外层的BFPRT算法是求原数组的第k小的元素,调用过程是这样的
int bfprt(int[] arr, int k)
,如果要求中位数数组middleArr的中位数,就转化为求这个数组的第middleArr.length/2
小的元素,所以说语义是一样的,内层的BFPRT算法调用过程就是int bfprt(int[] middleArr, int middleArr.length/2)
。
BFPRT算法的时间复杂度分析
假设BFPRT算法的数据规模是T(n);
- 第一步分组的时间复杂度:O(1);
- 第二步组内排序的时间复杂度:O(n);
- 第三步组成中位数数组的时间复杂度:O(n);
- 第四步递归调用BFPRT算法求中位数数组的中位数的时间复杂度:T(n/5);
- 第五步partition的时间复杂度:O(n);
- 第六步在未命中目标数的情况下,考虑最坏的情况:即左侧最多有多少个元素。反过来考虑,右侧最少有多少个元素。
中位数数组的长度是
n
5
\frac{n}{5}
5n ,选取中位数数组的中位数作为基准值,那么在这个中位数数组中,至少有
n
5
∗
1
2
=
n
10
\frac{n}{5}*\frac{1}{2}=\frac{n}{10}
5n∗21=10n 个元素比基准值大,这
n
10
\frac{n}{10}
10n个元素在各自的组中还有两个元素比它们自己大,所以在原数组中至少有
3
n
10
\frac{3n}{10}
103n 的元素比基准值大,也就是说最多有
7
n
10
\frac{7n}{10}
107n 的元素比基准值小,这些元素位于基准值的左侧。最终BFPRT算法的时间复杂度表示为:
T
(
n
)
=
T
(
n
5
)
+
T
(
7
n
10
)
+
O
(
n
)
T(n)=T(\frac{n}{5})+T(\frac{7n}{10})+O(n)
T(n)=T(5n)+T(107n)+O(n)
上面这个式子可以求出时间复杂度为O(n)。(证明过程见算法导论第九章)
BFPRT算法的代码实现
import java.util.Arrays;
import java.util.Random;
public class BFPRT {
/**
* 在无序数组arr中求出第k小的数
* @param arr 无序数组
* @param k 表示第k小,注意:k是从1开始计数的
* @return
*/
public static int getMinKthByBFPRT(int[] arr, int k) {
int[] copyArr = copyArray(arr);
return bfprt(copyArr, 0, copyArr.length-1, k-1);
}
private static int[] copyArray(int[] arr) {
int[] res = new int[arr.length];
for (int i = 0; i < res.length; i++) {
res[i] = arr[i];
}
return res;
}
/**
* 对于arr[left, right],使用bfprt算法求出第i小的数
* @param arr
* @param left
* @param right
* @param i 这里的i表示的是索引
* @return
*/
private static int bfprt(int[] arr, int left, int right, int i) {
if (left == right) {
return arr[left];
}
int pivot = getMedianOfMedians(arr, left, right);
int[] pivotRange = partition(arr, left, right, pivot);
if (i >= pivotRange[0] && i <= pivotRange[1]) {
return arr[i];
} else if (i < pivotRange[0]) {
return bfprt(arr, left, pivotRange[0]-1, i);
} else {
return bfprt(arr, pivotRange[1]+1, right, i);
}
}
/**
* 对arr[left, right]这一部分,进行分组,组内排序,组成中位数数组,返回中位数数组的中位数
* @param arr
* @param left
* @param right
* @return 中位数数组的中位数
*/
private static int getMedianOfMedians(int[] arr, int left, int right) {
int nums = right - left + 1;
int offset = nums % 5 == 0 ? 0 : 1;
int[] medians = new int[nums / 5 + offset]; // 中位数数组
for (int i = 0; i < medians.length; i++) {
// 确定这一段在原数组arr中的索引
int begin = left + i * 5;
int end = Math.min(begin+4, right);
medians[i] = getMedian(arr, begin, end);
}
// 调用bfprt算法求中位数数组的中位数
return bfprt(medians, 0, medians.length-1, medians.length/2);
}
private static int[] partition(int[] arr, int left, int right, int pivot) {
int small = left - 1;
int big = right + 1;
int cur = left;
while (cur < big) {
if (arr[cur] < pivot) {
swap(arr, ++small, cur++);
} else if (arr[cur] > pivot) {
swap(arr, cur, --big);
} else {
cur++;
}
}
int[] pivotRange = new int[2];
pivotRange[0] = small+1;
pivotRange[1] = big-1;
return pivotRange;
}
private static void swap(int[] nums, int i, int j) {
int copy = nums[i];
nums[i] = nums[j];
nums[j] = copy;
}
private static int getMedian(int[] arr, int left, int right) {
Arrays.sort(arr, left, right);
int sum = left + right;
int mid = (sum / 2) + (sum % 2); // 这里约定最后一组不满5个数时取上中位数
return arr[mid];
}
public static void main(String[] args) {
Random random = new Random();
for (int i = 0; i < 100; i++) {
int[] arr = new int[10];
for (int j = 0; j < arr.length; j++) {
arr[j] = random.nextInt(10);
}
int k = 5;
int res = getMinKthByBFPRT(arr, k);
Arrays.sort(arr);
if (res != arr[k-1]) {
System.out.println("ERROR! bfprt is " + res + ", sort is " + arr[k-1]);
}
}
}
}
参考资料:左神算法课程。