本文解题方法参考油管博主Tushar Roy,链接如下:
Binary Search : Median of two sorted arrays of different sizes.
问题重述
给定两个大小分别为m和n的排好序的数组X和Y,找到这两个数组的中位数。
例:
输入:X = [1, 2],Y = [3, 4]
输出:2.5
注:本文采用C++实现
解答
这道题最简单的暴力解法是直接将两个数组先拼接在一起,然后排序求出中位数即可。具体实现代码如下:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
auto len = nums1.size() + nums2.size();
vector<int> nums = nums1;
for(int i = 0; i < nums2.size(); i++)
nums.push_back(nums2[i]);
sort(nums.begin(),nums.end());
if(len%2 == 1)
return (double)nums[len/2];
else
return ((double)(nums[len/2]+nums[len/2-1]))/2;
}
下面来看看如何实现时间复杂度O(log(min(m,n)))的解法。
时间复杂度和log相关,最直接的想法就是尝试能不能用二分法来解决问题。
进一步分析发现,对于一个已经排好序的数组来说,我们只要将它从中间一分为二就能找到其中位数。那对于两个排好序的数组,如何使用二分法来找到它的中位数呢?
下面来看例子
X = [1,4,6,8,10]
Y = [2,3,7,9,12,14,15]
则这两个数组合并后是一个12个数字的有序数组,中位数为7.5。
此时,对于单独的X,Y数组而言,同样可以利用和7.5的大小关系将其分割为左右两半,则存在规律:
- X中左一半的数字个数加Y左一半的数字个数之和刚好是6,也就是合并后这个偶数个数组长度的一半
- 由于X,Y均已排好序,则X左一半的最后一个数必然小于Y右一半的第一个数,对Y的左一半和X的右一半有着同样的关系。
以上述两个规律为出发点,我们在不合并数组的情况下求取其中位数。
X长度 len_1 = 5,Y长度 len_2 = 7
假设X中左一半长度为find_x,Y中左一半长度为find_y
定义 lo = 0, hi = len_1,结合二分法来确定find_x,后根据规律1确定find_y,利用规律2来判断是否找到正确位置。
第一次循环:
find_x = (lo+hi)/2 = 2
find_y = (len_1+len_2)/2 - find_x = 4
此时X[find_x - 1] < X[find_y], 但X[find_y - 1] > X[find_x]
说明Y左边分配的数字多了一点,find_x应该增加,故令 lo = find_x +1 = 3
第二次循环
find_x = (lo+hi)/2 = 4
find_y = (len_1+len_2)/2 - find_x = 2
此时X[find_x - 1] > X[find_y], X[find_y - 1] < X[find_x]
说明X左边分配的数字多了一点,find_x应该减小,故令 hi = find_x -1 = 3
第三次循环
find_x = (lo+hi)/2 = 3
find_y = (len_1+len_2)/2 - find_x = 3
此时X[find_x - 1] < X[find_y], X[find_y - 1] < X[find_x]
满足规律2,说明找到中位数 mid = (max(X[find_x - 1],X[find_y - 1]) + min(X[find_x],X[find_y]))/2
下面贴上代码
int findMax(int a, int b)
{
return a>b?a:b;
}
int findMin(int a, int b)
{
return a>b?b:a;
}
double findMedianSortedArrays(vector<int> &nums1, vector<int> &nums2)
{
if (nums1.size() > nums2.size())
return findMedianSortedArrays(nums2, nums1);
int x = nums1.size();
int y = nums2.size();
int lo = 0, hi = x;
while (lo <= hi)
{
int partionX = (lo + hi) / 2;
int partionY = (x + y + 1) / 2 - partionX;
int maxLeftX = (partionX == 0) ? INT_MIN : nums1[partionX - 1];
int minRightX = (partionX == x) ? INT_MAX : nums1[partionX];
int maxLeftY = (partionY == 0) ? INT_MIN : nums2[partionY - 1];
int minRightY = (partionY == y) ? INT_MAX : nums2[partionY];
if (maxLeftX <= minRightY && maxLeftY <= minRightX)
{
if ((x + y) % 2 == 0)
return (double)(findMax(maxLeftX, maxLeftY) + findMin(minRightX, minRightY)) / 2;
else
return (double)findMax(maxLeftX, maxLeftY);
}
else if (maxLeftX > minRightY)
hi = partionX - 1;
else
lo = partionX + 1;
}
return -1;
}
PS:代码中需要注意的细节有:
- 利用nums1.size() > nums2.size()的判断来保证时间复杂度为O(log(min(m,n)))
- 数组长度之和为奇数和偶数时处理方法的异同
- 避免数组越界访问的实现