回想一下算法基础中解决此问题的思路:(描述借鉴网络资源)
不妨设数列A元素个数为n,数列B元素个数为m,各自升序排序,求第k小元素
取A[k / 2] B[k / 2] 比较,
如果 A[k / 2] > B[k / 2] 那么,所求的元素必然不在B的前k / 2个元素中(证明反证法)
反之,必然不在A的前k / 2个元素中,于是我们可以将A或B数列的前k / 2元素删去,求剩下两个数列的
k - k / 2小元素,于是得到了数据规模变小的同类问题,递归解决
如果 k / 2 大于某数列个数,所求元素必然不在另一数列的前k / 2个元素中,同上操作就好。
需要特别注意的几个问题:
边界条件:
1,nums1长度为0时,直接返回nums2[k-1]
2,k==1,返回min(nums1[sk1], nums2[sk2]) (主要为了节省递归的时间)
3,k/2可能比m,n小
4,如果nums1[k1-1],nums2[k2-1]相等,返回其中一个值即可
#include
#include
#include
#include
using namespace std;
class Solution {
public:
/* 取A[k / 2] B[k / 2] 比较,
如果 A[k / 2] > B[k / 2] 那么,所求的元素必然不在B的前k / 2个元素中(证明反证法)
反之,必然不在A的前k / 2个元素中,于是我们可以将A或B数列的前k / 2元素删去,求剩下两个数列的
k - k / 2小元素,于是得到了数据规模变小的同类问题,递归解决
如果 k / 2 大于某数列个数,所求元素必然不在另一数列的前k / 2个元素中,同上操作就好。
*/
double findk(vector
& nums1, int sk1, vector
& nums2, int sk2, int k) { //向量不知道如何截取部分,使用sk表明起始位置 //因此长度的计算也变化了 int m = nums1.size() - sk1; int n = nums2.size() - sk2; //保证 m>n 可以在后面只需要判断m是否为0,以及求k的中间值时只考虑m是否不够长 if (m > n) return(findk(nums2, sk2, nums1, sk1, k)); if (m == 0) return nums2[sk2 + k - 1]; if (k == 1) { return min(nums1[sk1], nums2[sk2]); } int k1 = min(k / 2, m); int k2 = k - k1; //使用反证法思考比较容易想通 if (nums1[sk1+k1-1] < nums2[sk2+k2-1]) { return findk(nums1, sk1 + k1, nums2, sk2, k - k1); } else if (nums1[sk1+k1-1] > nums2[sk2+k2-1]) { return findk(nums1, sk1, nums2, sk2+k2 , k - k2); } //nums1与nums2的值相等 else { return nums1[k1-1]; } } double findMedianSortedArrays(vector
& nums1, vector
& nums2) { double result; int m = nums1.size(); int n = nums2.size(); int total = m + n; //判断total奇偶 if (total & 1) { result = findk(nums1, 0, nums2, 0, total / 2 + 1); } else { result = findk(nums1, 0, nums2, 0, total / 2) + findk(nums1, 0, nums2, 0, total / 2 + 1); result = result / 2.0; } return result; } }; void main(void) { Solution sol; int a[] = { 1,2 }; int b[] = { 1,2 }; int len1 = sizeof(a) / sizeof(a[0]); int len2 = sizeof(b) / sizeof(b[0]); //vector
nums1; //vector
nums2; vector
nums1(&a[0], &a[len1]); vector
nums2(&b[0], &b[len2]); double result; result = sol.findMedianSortedArrays(nums1, nums2); cout << result << endl; system("pause"); }