题目
There are two sorted arrays nums1 and nums2 of size m and n respectively.
Find the median of the two sorted arrays. The overall run time complexity should be O(log (m+n)).
You may assume nums1 and nums2 cannot be both empty.
Example 1:
nums1 = [1, 3]
nums2 = [2]
The median is 2.0
Example 2:
nums1 = [1, 2]
nums2 = [3, 4]
The median is (2 + 3)/2 = 2.5
分析
这个题目的意思是给出两个已排好序的数组,要求我们给出包含两个数组中所有数的中位数。我首先想到的方法是利用两个循环变量index1, index2
来同时遍历这两个数组 nums1、nums2
,当 nums1[index1] < nums2[index2]
时index1++
,否则 index2++
,并且选出小的那个数,直到总共遍历了(m + n) / 2
个元素,有点类似于合并两个有序链表的算法。此时就可以找到中位数了:
- 如果总数是奇数,那么中位数就是最后一个选出来的数
- 如果总数是偶数,那么中位数就是最后两个选出来的数的平均数
代码如下:
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int m = nums1.size();
int n = nums2.size();
int mid = (m + n) / 2;
int index1 = 0, index2 = 0;
int old_tmp = 0, new_tmp = 0;;
while ((index1 + index2) <= mid) {
old_tmp = new_tmp;
if (index1 == m) new_tmp = nums2[index2++];
else if (index2 == n) new_tmp = nums1[index1++];
else if (nums1[index1] < nums2[index2]) new_tmp = nums1[index1++];
else new_tmp = nums2[index2++];
}
return (m + n) % 2 == 0 ? (old_tmp + new_tmp) / 2.0 : new_tmp;
}
};
这种方法的时间复杂度是 O((m + n) / 2)
, 并不满足题目要求的O(log(m + n))
,因此还是得找效率更高的方法。
解法一
实际上看到有 log
的复杂度,我们就应该想到要使用分治法,但这题要怎么使用二分法来把两个有序数组合并起来并找到中位数呢?其实,我们并不需要非得这样做,从中位数的定义入手,我们可以知道,如果数组中的一个数在把数组分成长度相等的两部分,且一部分的数值总大于等于另一部分,那这个数就是中位数,即:
把数组nums1
分成两个部分:
nums1[0],nums1[1]...nums1[i - 1] | nums1[i], nums1[i + 1]...nums[m - 1]
左边部分数目为i
,右边部分数目为m - i
且 max(left) <= min(right)
当i = m - i
时(nums1[i - 1] + nums1[i]) / 2
就是中位数
同理,我们无需把两个数组合并起来排序再找中位数,只需要把它们分成两个长度相等的部分,并使max(left) <= min(right)
就可以找到中位数了,即:
nums1[0]...nums1[i - 1] | nums1[i]...nums1[m - 1]
nums2[0]...nums2[j - 1] | nums2[j]...nums2[n - 1]
-
长度相等即:
i + j = m - i + n - j
,当总长度为奇数时默认左边会比右边少一个(由于整数除法的缘故,当n
为奇数时,n / 2 == (n - 1) / 2
)。
如果默认总数为奇数时左边多一个,则i + j = m - i + n - j + 1
。 -
max(left) <= min(right)
即:nums1[i - 1] <= nums2[j] && nums2[j - 1] <= nums1[i]
。
那么现在我们的问题就变成了找到这样的 i和j
来满足上面两个条件。
i
的范围是 [0, m]
(注意不是[0,m-1]
,因为i
是右半部分的第一个下标),由第一个条件可以知道 j = (m + n) / 2 - i
,当 m <= n
时 0 <= j <= n
,否则可能为负数,所以需要注意保证m <= n
。
现在我们就可以用二分法来解决这个问题了,伪代码如下:
1.imin = 0, imax = m
2.j = (m + n) / 2 - i
3.
if nums1[i - 1] > nums2[j] imax = i - 1 //i较大,因此需要减小遍历i的范围
else if nums2[j - 1] > nums1[i] imin = i + 1 //i较小,因此需要增大遍历i的范围
else //找到了合适的i
if (m + n) % 2 == 0 return (max(left) + min(right)) / 2.0
else return min(right)
细节: 减小范围时,应该用i-1
和i+1
,因为i
的范围是imin,imax
,而当前的i
已经被验证过不可能是答案了。
此外,我们需要考虑临界问题:i = 0, i = m, j = 0, j = n
时怎么办?访问nums1[i - 1],nums2[j],nums2[j - 1],nums1[i]
是可能越界的。
很简单,当i == 0
时nums[i - 1]
取INT_MIN
,当i == m
时取INT_MAX
,因为左边要取最大值,右边要取最小值,所以设置这两个极值对结果不会有影响。同理对j
进行设置。
代码
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int m = nums1.size();
int n = nums2.size();
if (m > n) return findMedianSortedArrays(nums2, nums1);
int imin = 0, imax = m;
int max_left = 0, min_right = 0;
while (imin <= imax) {
int i = (imin + imax) / 2;
int j = (m + n) / 2 - i;
int i_left = (i == 0 ? INT_MIN : nums1[i - 1]);
int j_left = (j == 0 ? INT_MIN : nums2[j - 1]);
int i_right = (i == m ? INT_MAX : nums1[i]);
int j_right = (j == n ? INT_MAX : nums2[j]);
if (i_left > j_right) imax = i - 1;
else if (j_left > i_right) imin = i + 1;
else {
max_left = max(i_left, j_left);
min_right = min(i_right, j_right);
break;
}
}
return (m + n) % 2 == 0 ? (max_left + min_right) / 2.0 : min_right;
}
};
复杂度分析
我们只在长度更小的那个数组上进行了二分查找,每次都使得要查找的范围减半,因此时间复杂度为 O ( l o g ( m i n ( m , n ) ) ) O(log(min(m,n))) O(log(min(m,n))),空间复杂度为 O ( 1 ) O(1) O(1)。
解法二
与分析中想到的思路类似,找中位数,我们就是要找到第mid个元素,如何通过二分法来节省遍历的次数呢?
假设我们要找的是第k个元素,那么对nums1[k / 2 - 1]
和nums2[k / 2 - 1]
进行判断:
nums1[k / 2 - 1] < nums2[k / 2 - 1]
:这说明比nums1[k / 2 - 1]
小的元素最多有(k / 2 - 1) * 2 = k - 2
个,所以它最大也只是第k-1个元素。我们把nums1[0] ~ nums1[k / 2 - 1]
排除掉,相应地k也要减去排除掉的元素个数(k -= k/2
)。nums1[k / 2 - 1] > nums2[k / 2 - 1]
:与1相反的情况,思路相同。nums1[k / 2 - 1] == nums2[k / 2 - 1]
:可以纳入情况1处理。
需要注意的是,在这种思路中,第k个元素的下标是k-1。
有几种特殊情况需要处理:
- 当
nums1
为空时,第k个元素就是nums2[k - 1]
- 当
nums1.size() < (k / 2 - 1)
时,取nums1
的最后一个元素来进行比较 - 当
k == 1
时,直接返回min(nums1[0], nums2[0])
代码
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int m = nums1.size(), n = nums2.size();
if ((m + n) % 2 == 0)
return (findKElement((m + n) / 2 + 1, nums1, nums2) + findKElement((m + n) / 2, nums1, nums2)) / 2;
else
return findKElement((m + n + 1) / 2, nums1, nums2);
}
double findKElement(int k, vector<int> nums1, vector<int> nums2){
while (k > 1 && !nums1.empty() && !nums2.empty()) {
int a1 = min(k / 2 - 1, (int)nums1.size() - 1);
int a2 = min(k / 2 - 1, (int)nums2.size() - 1);
if (nums1[a1] <= nums2[a2]) {
nums1.erase(nums1.begin(), nums1.begin() + a1 + 1);
k -= (a1 + 1);
}
else {
nums2.erase(nums2.begin(), nums2.begin() + a2 + 1);
k -= (a2 + 1);
}
}
if (nums1.empty()) return (double)nums2[k - 1];
if (nums2.empty()) return (double)nums1[k - 1];
return (double)min(nums1[0], nums2[0]);
}
};
复杂度分析
按照这种解法,每次都至少能缩小k / 2
的范围,这种解法就是题干要求的解法,时间复杂度为
O
(
l
o
g
(
m
+
n
)
)
O(log(m+n))
O(log(m+n)),空间复杂度为
O
(
1
)
O(1)
O(1)(注意由于代码的实现不完美,复制了两个数组导致空间复杂度为
O
(
m
+
n
)
O(m+n)
O(m+n),其实可以通过维护index来实现数组元素的移除)。