在无序数组中找到第k大的数
1)分组,每N个数一组,(一般5个一组)
2)每组分别进行排序,组间不排序
3)将每个组的中位数拿出来,若偶数,则拿上 / 下中位数, 成立一个一个新数组。
4)新数组递归调用BFPRT,则拿到整体的中位数num
5)以num来划分整体数组,小于在左,大于在右边,使用【荷兰国旗方法】
6)然后根据左右数组的规模,来确定进一步选择左右哪一部分;
7)然后选择好后,继续
一:背景介绍
在一大堆数中求其前k大或前k小的问题,简称TOP - K问题。而目前解决TOP - K问题最有效的算法即是BFPRT算法,其又称为中位数的中位数算法,该算法由Blum、Floyd、Pratt、Rivest、Tarjan提出,最坏时间复杂度为O(n)O(n)。
在首次接触TOP - K问题时,我们的第一反应就是可以先对所有数据进行一次排序,然后取其前k即可,但是这么做有两个问题:
(1):快速排序的平均复杂度为O(nlogn),但最坏时间复杂度为O(n2),不能始终保证较好的复杂度。
(2):我们只需要前k大的,而对其余不需要的数也进行了排序,浪费了大量排序时间。
除这种方法之外,堆排序也是一个比较好的选择,可以维护一个大小为k的堆,时间复杂度为O(nlogk)。
那是否还存在更有效的方法呢?受到快速排序的启发,通过修改快速排序中主元的选取方法可以降低快速排序在最坏情况下的时间复杂度(即BFPRT算法),并且我们的目的只是求出前k,故递归的规模变小,速度也随之提高。下面来简单回顾下快速排序的过程,以升序为例:
(1):选取主元(首元素,尾元素或一个随机元素);
(2):以选取的主元为分界点,把小于主元的放在左边,大于主元的放在右边;
(3):分别对左边和右边进行递归,重复上述过程。
二:BFPRT算法过程及代码
BFPRT算法步骤如下:
(1):选取主元;
(1.1):将n个元素划分为[n / 5]个组,每组5个元素,若有剩余,舍去;
(1.2):使用插入排序找到[n / 5]个组中每一组的中位数;
(1.3):对于(1.2)中找到的所有中位数,调用BFPRT算法求出它们的中位数,作为主元;
(2):以(1.3)选取的主元为分界点,把小于主元的放在左边,大于主元的放在右边;
(3):判断主元的位置与k的大小,有选择的对左边或右边递归。//即根据K的位置判断选择哪一部分继续迭代
上面的描述可能并不易理解,先看下面这幅图:
BFPRT()调用GetPivotIndex()和Partition()来求解第k小,在这过程中,GetPivotIndex()也调用了BFPRT(),即GetPivotIndex)和BFPRT()为互递归的关系。
下面为代码实现,其所求为前K小的数:
C++代码:
1 /******C++**********/ 2 #include<iostream> 3 #include<algorithm> 4 using namespace std; 5 6 int InsertSort(int array[], int left, int right); //插入排序,返回中位数下标 7 int GetPivotIndex(int array[], int left, int right); //返回中位数的中位数下标 8 int Partition(int array[], int left, int right, int pivot_index); //利用中位数的中位数的下标进行划分,返回分界线下标 9 int BFPRT(int array[], int left, int right, const int & k); //求第k小,返回其位置的下标 10 11 int main() 12 { 13 int k = 5; 14 int array[10] = { 1,1,2,3,1,5,-1,7,8,-10 }; 15 16 cout << "原数组:"; 17 for (int i = 0; i < 10; i++) 18 cout << array[i] << " "; 19 cout << endl; 20 21 cout << "第" << k << "小值为:" << array[BFPRT(array, 0, 9, k)] << endl; 22 23 cout << "变换后的数组:"; 24 for (int i = 0; i < 10; i++) 25 cout << array[i] << " "; 26 cout << endl; 27 28 return 0; 29 } 30 31 /* 插入排序,返回中位数下标 */ 32 int InsertSort(int array[], int left, int right) 33 { 34 int temp; 35 int j; 36 for (int i = left + 1; i <= right; i++) 37 { 38 temp = array[i]; 39 j = i - 1; 40 while (j >= left && array[j] > temp) 41 array[j + 1] = array[j--]; 42 array[j + 1] = temp; 43 } 44 45 return ((right - left) >> 1) + left; 46 } 47 48 /* 返回中位数的中位数下标 */ 49 int GetPivotIndex(int array[], int left, int right) 50 { 51 if (right - left < 5) 52 return InsertSort(array, left, right); 53 54 int sub_right = left - 1; 55 for (int i = left; i + 4 <= right; i += 5) 56 { 57 int index = InsertSort(array, i, i + 4); //找到五个元素的中位数的下标 58 swap(array[++sub_right], array[index]); //依次放在左侧 59 } 60 61 return BFPRT(array, left, sub_right, ((sub_right - left + 1) >> 1) + 1); 62 } 63 64 /* 利用中位数的中位数的下标进行划分,返回分界线下标 */ 65 int Partition(int array[], int left, int right, int pivot_index) 66 { 67 swap(array[pivot_index], array[right]); //把基准放置于末尾 68 69 int divide_index = left; //跟踪划分的分界线 70 for (int i = left; i < right; i++) 71 { 72 if (array[i] < array[right]) 73 swap(array[divide_index++], array[i]); //比基准小的都放在左侧 74 } 75 76 swap(array[divide_index], array[right]); //最后把基准换回来 77 return divide_index; 78 } 79 80 int BFPRT(int array[], int left, int right, const int & k) 81 { 82 int pivot_index = GetPivotIndex(array, left, right); //得到中位数的中位数下标 83 int divide_index = Partition(array, left, right, pivot_index); //进行划分,返回划分边界 84 int num = divide_index - left + 1; 85 if (num == k) 86 return divide_index; 87 else if (num > k) 88 return BFPRT(array, left, divide_index - 1, k); 89 else 90 return BFPRT(array, divide_index + 1, right, k - num); 91 }
Java代码:
1 public class BFPRT { 2 //前k小 3 public static int[] getMinKNumsByBFPRT(int[] arr, int k) { 4 if (k < 1 || k > arr.length) { 5 return arr; 6 } 7 int minKth = getMinKthByBFPRT(arr, k); 8 int[] res = new int[k]; 9 int index = 0; 10 for (int i = 0; i != arr.length; i++) { 11 if (arr[i] < minKth) { 12 res[index++] = arr[i]; 13 } 14 } 15 for (; index != res.length; index++) { 16 res[index] = minKth; 17 } 18 return res; 19 } 20 //第k小 21 public static int getMinKthByBFPRT(int[] arr, int K) { 22 int[] copyArr = copyArray(arr); 23 return select(copyArr, 0, copyArr.length - 1, K - 1); 24 } 25 26 public static int[] copyArray(int[] arr) { 27 int[] res = new int[arr.length]; 28 for (int i = 0; i != res.length; i++) { 29 res[i] = arr[i]; 30 } 31 return res; 32 } 33 //给定一个数组和范围,求第i小的数 34 public static int select(int[] arr, int begin, int end, int i) { 35 if (begin == end) { 36 return arr[begin]; 37 } 38 int pivot = medianOfMedians(arr, begin, end);//划分值 39 int[] pivotRange = partition(arr, begin, end, pivot); 40 if (i >= pivotRange[0] && i <= pivotRange[1]) { 41 return arr[i]; 42 } 43 else if (i < pivotRange[0]) { 44 return select(arr, begin, pivotRange[0] - 1, i); 45 } 46 else { 47 return select(arr, pivotRange[1] + 1, end, i); 48 } 49 } 50 //在begin end范围内进行操作 51 public static int medianOfMedians(int[] arr, int begin, int end) { 52 int num = end - begin + 1; 53 int offset = num % 5 == 0 ? 0 : 1;//最后一组的情况 54 int[] mArr = new int[num / 5 + offset];//中位数组成的数组 55 for (int i = 0; i < mArr.length; i++) { 56 int beginI = begin + i * 5; 57 int endI = beginI + 4; 58 mArr[i] = getMedian(arr, beginI, Math.min(end, endI)); 59 } 60 return select(mArr, 0, mArr.length - 1, mArr.length / 2); 61 //只不过i等于长度一半,用来求中位数 62 } 63 //经典partition过程 64 public static int[] partition(int[] arr, int begin, int end, int pivotValue) { 65 int small = begin - 1; 66 int cur = begin; 67 int big = end + 1; 68 while (cur != big) { 69 if (arr[cur] < pivotValue) { 70 swap(arr, ++small, cur++); 71 } 72 else if (arr[cur] > pivotValue) { 73 swap(arr, cur, --big); 74 } 75 else { 76 cur++; 77 } 78 } 79 int[] range = new int[2]; 80 range[0] = small + 1; 81 range[1] = big - 1; 82 return range; 83 } 84 //五个数排序,返回中位数 85 public static int getMedian(int[] arr, int begin, int end) { 86 insertionSort(arr, begin, end); 87 int sum = end + begin; 88 int mid = (sum / 2) + (sum % 2); 89 return arr[mid]; 90 } 91 //手写排序 92 public static void insertionSort(int[] arr, int begin, int end) { 93 for (int i = begin + 1; i != end + 1; i++) { 94 for (int j = i; j != begin; j--) { 95 if (arr[j - 1] > arr[j]) { 96 swap(arr, j - 1, j); 97 } 98 else { 99 break; 100 } 101 } 102 } 103 } 104 //交换值 105 public static void swap(int[] arr, int index1, int index2) { 106 int tmp = arr[index1]; 107 arr[index1] = arr[index2]; 108 arr[index2] = tmp; 109 } 110 //打印 111 public static void printArray(int[] arr) { 112 for (int i = 0; i != arr.length; i++) { 113 System.out.print(arr[i] + " "); 114 } 115 System.out.println(); 116 } 117 118 public static void main(String[] args) { 119 int[] arr = { 6, 9, 1, 3, 1, 2, 2, 5, 6, 1, 3, 5, 9, 7, 2, 5, 6, 1, 9 }; 120 // sorted : { 1, 1, 1, 1, 2, 2, 2, 3, 3, 5, 5, 5, 6, 6, 6, 7, 9, 9, 9 } 121 printArray(getMinKNumsByBFPRT(arr, 10)); 122 123 } 124 } 125