题目
给定两个大小分别为 m
和 n
的正序(从小到大)数组 nums1
和 nums2
。请你找出并返回这两个正序数组的 中位数 。
示例 1:
输入:nums1 = [1,3], nums2 = [2]
输出:2.00000
解释:合并数组 = [1,2,3] ,中位数 2
示例 2:
输入:nums1 = [1,2], nums2 = [3,4]
输出:2.50000
解释:合并数组 = [1,2,3,4] ,中位数 (2 + 3) / 2 = 2.5
示例 3:
输入:nums1 = [0,0], nums2 = [0,0]
输出:0.00000
示例 4:
输入:nums1 = [], nums2 = [1]
输出:1.00000
示例 5:
输入:nums1 = [2], nums2 = []
输出:2.00000
提示:
- n u m s 1. l e n g t h = = m nums1.length == m nums1.length==m
- n u m s 2. l e n g t h = = n nums2.length == n nums2.length==n
- 0 < = m < = 1000 0 <= m <= 1000 0<=m<=1000
- 0 < = n < = 1000 0 <= n <= 1000 0<=n<=1000
- 1 < = m + n < = 2000 1 <= m + n <= 2000 1<=m+n<=2000
- − 1 0 6 < = n u m s 1 [ i ] , n u m s 2 [ i ] < = 1 0 6 -10^6 <= nums1[i], nums2[i] <= 10^6 −106<=nums1[i],nums2[i]<=106
进阶:你能设计一个时间复杂度为 O(log (m+n))
的算法解决此问题吗?
分析
1.常规思想
先合并得到一个更大的有序数组,大的有序数组中间的位置的元素就是中位数。
时间复杂度:
O
(
m
+
n
)
O(m+n)
O(m+n)
空间复杂度:
O
(
m
+
n
)
O(m+n)
O(m+n)
2.常规思想优化
【优化】
方案就是假合并,并不需要真正将两个数组进行合并,只需要找到中位数的位置即可。
【思路】
已知两个数组的长度,所以中位数对应的两个数组的下标之和也是已知的。维护两个指针,初始分别指向两个数组的下标 0 位置处,每次将指向较小值的指针后移一位(如果一个指针已经到达数组末尾,则只需要移动另一个数组的指针), 直到到达中位数的位置。)
时间复杂度:
O
(
m
+
n
)
O(m+n)
O(m+n)
空间复杂度:
O
(
1
)
O(1)
O(1)
注意:在代码实现上,不仅要考虑奇偶问题,还要考虑一个数组遍历结束后的 各种边界问题。
此方法的优化点是将奇偶两种情况合并到一起。具体思想为如下:
- 如果是 奇数,只需要知道第 l e n + 1 2 \frac{len+1}{2} 2len+1 个元素即可,需要遍历的次数为 l e n 2 + 1 \frac{len}{2} + 1 2len+1 次;
- 如果是 偶数,需要知道第 l e n 2 \frac{len}{2} 2len 个和 l e n 2 + 1 \frac{len}{2} + 1 2len+1 个元素,需要遍历的次数也是 l e n 2 + 1 \frac{len}{2} + 1 2len+1次。
- 返回中位数时,奇数只需要最后一次的遍历结果即可;偶数需要最后一次和倒数第二次的遍历结果。
【代码】
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int prev, now;
int n = nums1.size(), m = nums2.size();
int times = (n + m) / 2 + 1; //遍历次数
int i = 0, j = 0;
while (i + j < times) {
prev = now; //记录上一次遍历得到结果
if (j >= m || (i < n && nums1[i] < nums2[j])) {
now = nums1[i++];
} else {
now = nums2[j++];
}
}
if ((n + m) % 2) return now; //如果长度为奇数,就是最后一次遍历得到的结果
return (prev + now)/2.0;//如果长度为偶数,就是最后一次和上一次遍历得到的结果除以2
}
};
3.寻找第 k 小数 : 二分查找
【转化】
根据中位数的定义,当
m
+
n
m + n
m+n 为奇数时,中位数是两个有序数组中的第
(
m
+
n
)
/
2
(m+n)/2
(m+n)/2 个元素;当
m
+
n
m + n
m+n 为 偶数时,中位数是两个有序数组中的第
(
m
+
n
)
/
2
(m + n)/2
(m+n)/2 和
(
m
+
n
)
/
2
+
1
(m+n)/2 + 1
(m+n)/2+1 个元素的平均值。
因此,本题可以转化为 寻找两个有序数组中的第
k
k
k 小的数,其中
k
k
k 为
(
m
+
n
)
/
2
(m+n)/2
(m+n)/2 或
(
m
+
n
)
/
2
+
1
(m+n)/2 + 1
(m+n)/2+1。
「第 k k k 小数」 的思想主要就是根据两个数的三种比较结果,不断去除不满足的元素。
【三种情况】
- 如果
A[k/2 - 1] < B[k/2 - 1]
,则比 A [ k / 2 − 1 ] A[k/2 - 1] A[k/2−1] 小的数 最多 只有 A A A 的前 k / 2 − 1 k/2 - 1 k/2−1 个数和 B B B的前 k / 2 − 1 k/2-1 k/2−1 个数,即比 A [ k / 2 − 1 ] A[k/2 - 1] A[k/2−1] 小的数最多只有 k − 2 k - 2 k−2 个,因此 A [ k / 2 − 1 ] A[k/2-1] A[k/2−1] 不可能是第 k k k 个数,A[0] ~ A[k/2-1] 也都不可能是第 k k k 个数,可以 全部排除。 - 如果
A[k/2 - 1] > B[k/2 - 1]
,则可以排除B[0] ~B[k/2 - 1]。 - 如果
A[k/2 - 1] == B[k/2 - 1]
,则可以 归入第一种情况 处理。
【处理结果】
比较 A[k/2 - 1]
和 B[k/2 - 1]
之后,可以 排除
k
/
2
k/2
k/2个 不可能是第
k
k
k 小的数,在查找范围缩小了一半。同时,在排除后的 新数组上进行二分查找,并且根据 排除数的个数,减小
k
k
k 的值,因为排除的数都不大于第
k
k
k 小的数。
【三种特殊情况】
- 如果
A[k/2 - 1]
或者B[k/2 - 1]
越界,那么可以选取对应数组中的最后一个元素。在这种情况下,必须根据排除数的个数减小 k k k 的值,而不能直接将 k k k 减去 k / 2 k/2 k/2。 - 如果 一个数组为空,说明该数组中的所有元素都被排除,可以直接 返回另一个数组中第 k k k 小的元素。
- 如果
k = 1
,只需要返回两个数组 首元素的最小值 即可。
【代码】
class Solution {
public:
int getKthElement(vector<int> &nums1, vector<int> &nums2, int k) {
int ind1 = 0, ind2 = 0;
int len1 = nums1.size(), len2 = nums2.size();
while (true) {
//特殊情况
if (ind1 == len1) return nums2[ind2 + k - 1]; //nums1数组为空
if (ind2 == len2) return nums1[ind1 + k - 1]; //nums2数组为空
if (k == 1) return min(nums1[ind1], nums2[ind2]); //k = 1
//正常情况
//ind1和ind2作为起始点,newInd1和newInd2 作为比较点在不断更新
int half = k / 2;
int newInd1 = min(ind1 + half, len1) - 1; //发生越界,记录需要比较的位置
int newInd2 = min(ind2 + half, len2) - 1; //发生越界,记录需要比较的位置
int val1 = nums1[newInd1], val2 = nums2[newInd2]; //获取两个需要比较的位置
if (val1 <= val2) {
k -= (newInd1 - ind1 + 1); //去除掉不符合要求的那部分数据
ind1 = newInd1 + 1; //连同比较位置一起删除,新的开始位置是 比较位置 的后一位
} else {
k -= (newInd2 - ind2 + 1);
ind2 = newInd2 + 1;
}
}
}
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int n = nums1.size(), m = nums2.size();
int totalLength = n + m;
if (totalLength % 2) {
int midIndex = totalLength / 2;
return getKthElement(nums1, nums2, midIndex + 1);
} else {
int midIndex1 = totalLength / 2 - 1, midIndex2 = totalLength / 2;
return (getKthElement(nums1, nums2, midIndex1 + 1) + getKthElement(nums1, nums2, midIndex2 + 1)) / 2.0;
}
}
};
【复杂度】
- 时间复杂度: O ( l o g ( m + n ) ) O(log(m + n)) O(log(m+n)),其中 m m m 和 n n n 分别是数组 n u m s 1 nums1 nums1 和 n u m s 2 nums2 nums2 的长度,初始时有 k = ( m + n ) / 2 k = (m + n) / 2 k=(m+n)/2 或 k = ( m + n ) / 2 + 1 k = (m + n) / 2 + 1 k=(m+n)/2+1,每轮循环可以将查找范围缩小一半,因此时间复杂度为 O ( l o g ( m + n ) ) O(log(m+n)) O(log(m+n))。
- 空间复杂度: O ( 1 ) O(1) O(1)
【另一种分析】参考自LeetCode-4 寻找两个有序数组的中位数
一个有序数组的中位数,当有序数组的个数为奇数时,如 nums=[1, 2, 3, 4, 5]
,该数组的中位数为nums[2]=3
;当有序数组的个数为偶数时,如 nums=[1, 2, 3, 4, 5, 6]
,该数组的中位数为(nums[2]+nums[3])/2=3.5
。
如下图所示,用同一公式可求出任意个数有序数组的中位数。
对于两个有序数组来说,只要找出第
(
m
+
n
+
1
)
/
2
(m+n+1)/2
(m+n+1)/2 大的数和第
(
m
+
n
+
2
)
/
2
(m+n+2)/2
(m+n+2)/2 的数,然后求平均数即可。此处的
m
m
m 和
n
n
n 分别指两个数组的大小,
m
+
n
m + n
m+n 如图中的 nums.length
,第
(
m
+
n
+
1
)
/
2
(m+n+1)/2
(m+n+1)/2 大的数是指假设这两个数组合成一个有序数组后找出第
(
m
+
n
+
1
)
/
2
(m+ n+ 1) / 2
(m+n+1)/2 大的数(此处不像图中进行减1,因为这里说的第几大的数是从下标 1 开始的;图中减 1 是因为使用的数组,下标从 0 开始)。
接下来在两个有序数组中找到第
(
m
+
n
+
1
)
/
2
(m+n+1)/2
(m+n+1)/2 大的数和 第
(
m
+
n
+
2
)
/
2
(m+n+2)/2
(m+n+2)/2 大的数,抽象后可表述为在两个有序数组中找第 k 大的数。因为进阶中要求的时间复杂度为
O
(
l
o
g
(
m
+
n
)
)
O(log(m+ n))
O(log(m+n)),可以想到 二分查找。
查找时需要考虑一些特殊情况:
- 当某个数组查找的起始位置大于等于该数组长度时,说明这个数组中的所有数已经被淘汰,则只需要在另一个数组中查找即可。
- 如果 k = 1 k = 1 k=1 时,即需要查找第一个数,则找到两个数组起始位置中最小的那个即可。
处理完特殊情况后,分析一般情况:此处说的二分是指对数组的大小进行二分还是对 k k k 进行二分? 以前对一维数组进行二分查找时,一般都是对数组的大小进行二分,而这里需要 对 k k k 进行二分。意思是,需要在两个数组查找第 k / 2 k / 2 k/2 大的数,由于这两个数组的长度不定,有可能存在有一个数组中没有第 k / 2 k/2 k/2 大的数,如果没有则赋值为整型最大值。
【代码】
//递归实现
class Solution {
public:
// 在两个有序数组中二分查找第k大元素
int getKthElement(vector<int> &nums1, int start1, vector<int> &nums2, int start2, int k) {
int m = nums1.size(), n = nums2.size();
//特殊情况
if (start1 > m - 1) return nums2[start2 + k - 1];
if (start2 > n - 1) return nums1[start1 + k - 1];
if (k == 1) return min(nums1[start1], nums2[start2]);
// 分别在两个数组中查找第k/2个元素,若存在(即数组没有越界),标记为找到的值;若不存在,标记为整数最大值
int nums1mid = start1 + k / 2 - 1 < nums1.size() ? nums1[start1 + k / 2 - 1] : INT_MAX;
int nums2mid = start2 + k / 2 - 1 < nums2.size() ? nums2[start2 + k / 2 - 1] : INT_MAX;
// 确定最终的第k/2个元素,然后递归查找
if (nums1mid < nums2mid) {
return getKthElement(nums1, start1 + k / 2, nums2, start2, k - k / 2);
} else {
return getKthElement(nums1, start1, nums2, start2 + k / 2, k - k / 2);
}
}
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int m = nums1.size(), n = nums2.size();
int l = (m + n + 1) / 2;
int r = (m + n + 2) / 2;
return (getKthElement(nums1, 0, nums2, 0, l) + getKthElement(nums1, 0, nums2, 0, r)) / 2.0;
}
};
【代码】
//循环实现:时间复杂度O(log(m+n)),空间复杂度O(1)
class Solution {
public:
int getKthElement(vector<int> &nums1, int start1, vector<int> &nums2, int start2, int k) {
while (k != 1) {
if (start1 > (int)(nums1.size() - 1)) return nums2[start2 + k - 1];
if (start2 > (int)(nums2.size() - 1)) return nums1[start1 + k - 1];
int nums1mid = start1 + k / 2 - 1 < nums1.size() ? nums1[start1 + k / 2 - 1] : INT_MAX;
int nums2mid = start2 + k / 2 - 1 < nums2.size() ? nums2[start2 + k / 2 - 1] : INT_MAX;
if (nums1mid < nums2mid) start1 += k / 2;
else start2 += k / 2;
k -= k / 2;
}
//判断是否越界
if (start1 > (int)(nums1.size() - 1)) return nums2[start2 + k - 1];
if (start2 > (int)(nums2.size() - 1)) return nums1[start1 + k - 1];
return min(nums1[start1], nums2[start2]);
}
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int m = nums1.size(), n = nums2.size();
int l = (m + n + 1) / 2;
int r = (m + n + 2) / 2;
return (getKthElement(nums1, 0, nums2, 0, l) + getKthElement(nums1, 0, nums2, 0, r)) / 2.0;
}
};
4. 划分数组
划分数组的思想的时间复杂度比二分查找还低,二分查找的思想每轮循环可以将查找范围减小一般,因此时间复杂度为 O ( l o g ( m + n ) ) O(log(m + n)) O(log(m+n)),但是划分数组的思想可以确定对较短的数组进行二分查找,所以它的时间复杂度是 O ( l o g m i n ( m , n ) ) O(log \ min(m, n)) O(log min(m,n))。
需要明白中位数的作用:将一个集合划分为两个长度相等的子集,其中一个子集中的元素总是大于另一个子集中的元素,这种思想无论是在几个数组中都是适用的,于是衍生出了下面的算法思想。
首先,讨论奇偶的两种不同情况下的不同划分方式:
-
当 A 和 B 的总长度为 偶数 时,如果可以确定:
- l e n ( l e f t _ p a r t ) = l e n ( r i g h t _ p a r t ) len(left\_part) = len(right\_part) len(left_part)=len(right_part)
- m a x ( l e f t _ p a r t ) < = m i n ( r i g h t _ p a r t ) max(left\_part) <= min(right\_part) max(left_part)<=min(right_part)
- 那么
{A, B}
中的所有元素已经被划分为相同长度的两个部分,且前一部分中的元素总是小于或等于后一部分中的元素。中位数就是前一部分的最大值和后一部分的最小值的平均值。
-
当 A 和 B 的总长度为 奇数 时,如果可以确定:
- l e n ( l e f t _ p a r t ) = l e n ( r i g h t _ p a r t ) + 1 len(left\_part) = len(right\_part) + 1 len(left_part)=len(right_part)+1
- m a x ( l e f t _ p a r t ) < = m i n ( r i g h t _ p a r t ) max(left\_part) <= min(right\_part) max(left_part)<=min(right_part)
- 那么,
{A,B}
中的所有元素已经被划分为两部分:前一部分比后一部分多一个元素,且前一个部分中元素总是小于或等于后一部分中的元素。中位数就是前一部分的最大值。
然后,编写代码时,由于计算机的取整操作,可以将这两种情况合并成一种代码书写方式。其中 i i i 和 j j j 分别为链各个数组的划分位置。
- 当 m + n m + n m+n 为偶数: i + j = m − i + n − j i + j = m - i + n - j i+j=m−i+n−j;当 m + n m + n m+n 为奇数: i + j = m − i + n − j + 1 i + j = m - i + n - j + 1 i+j=m−i+n−j+1。等号左侧为前一部分的元素个数,右侧为后一部分的元素个数。等式变形得到 i + j = ( m + n + 1 ) / 2 i + j = (m + n+1) / 2 i+j=(m+n+1)/2,这里的分数结果只保留整数部分。
- 规定 A 的长度小于等于 B 的长度,即 m < = n m<=n m<=n。这样对于 任意的 i ∈ [ 0 , m ] i ∈ [0, m] i∈[0,m] ,都有 ( j = ( m + n + 1 ) / 2 − i ) ∈ [ 0 , n ] (j = (m + n + 1) / 2 - i) ∈ [0, n] (j=(m+n+1)/2−i)∈[0,n]。如果 A 的长度较长,只需要 交换 A 和 B 即可。如果 m > n m > n m>n,那么得出的 j j j 可能是负数。
- B [ j − 1 ] < = A [ i ] B[j - 1] <= A[i] B[j−1]<=A[i] 以及 A [ i − 1 ] < = B [ j ] A[i - 1] <= B[j] A[i−1]<=B[j] ,即前一部分的最大值小于等于后一部分的最小值
接着,处理边界问题:
- 假设
A[i - 1], B[j - 1], A[i], B[j]
总是存在。对于 i = 0 、 i = m 、 j = 0 、 j = n i = 0、i = m、j = 0、j = n i=0、i=m、j=0、j=n 这样的临界条件,只需要规定A[-1] = B[-1] = -∞
,A[m] = B[n] = +∞
即可。这也是比较直观的:
- 当一个数组不出现在前一部分时,对应的值为 负无穷,就不会对前一部分的最大值产生影响;
- 当一个数组不出现在后一部分时,对应的值为 正无穷, 就不会对后一部分的最小值产生影响。
最后,编码。需要对两个条件进行判断: B [ j − 1 ] < = A [ i ] B[j - 1] <= A[i] B[j−1]<=A[i] 以及 A [ i − 1 ] < = B [ j ] A[i - 1] <= B[j] A[i−1]<=B[j]。这两种情况是可以等价转换的,需要一个条件的判断即可。
- 在 [ 0 , m ] [0, m] [0,m] 中找到 i i i,使 B [ j − 1 ] < = A [ i ] 且 A [ i − 1 ] < = B [ j ] B[j - 1] <= A[i] 且 A[i - 1] <= B[j] B[j−1]<=A[i]且A[i−1]<=B[j],其中 j = ( m + n + 1 ) / 2 − i j = (m + n + 1)/2 - i j=(m+n+1)/2−i 等价于 [ 0 , m ] [0,m] [0,m] 中找到 i i i,使 A [ i − 1 ] < = B [ j ] A[i - 1] <= B[j] A[i−1]<=B[j]
- 当 i i i 从 0 ~ m 递增时,A[i - 1] 递增,B[j] 递减,所以 一定存在一个最大的 i i i 满足 A [ i − 1 ] < = B [ j ] A[i - 1] <= B[j] A[i−1]<=B[j]
- 如果 i i i 是最大的,那么说明 i + 1 i + 1 i+1 不满足。将 i + 1 i+1 i+1 代入可以得到 A [ i ] > B [ j − 1 ] A[i] > B[j - 1] A[i]>B[j−1],也就是 B [ j − 1 ] < A [ i ] B[j - 1] < A[i] B[j−1]<A[i],就和进行等价变换前的 i i i 的性质一致了,甚至还要更强。
- 因此可以对 i i i 在区间[0, m] 上进行 二分搜索,找到最大的满足 A [ i − 1 ] < = B [ j ] A[i - 1] <= B[j] A[i−1]<=B[j] 的 i i i 值,就得到了划分的方法。
【代码】
//官方题解
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
//确保nums1的长度 <= nums2的长度
if (nums1.size() > nums2.size()) return findMedianSortedArrays(nums2, nums1);
int m = nums1.size(), n = nums2.size();
int left = 0, right = m;
int maxVal = 0, minVal = 0; //前一部分的最大值和后一部分的最小值
while (left <= right) {
// 前一部分包含 nums1[0 .. i-1] 和 nums2[0 .. j-1]
// 后一部分包含 nums1[i .. m-1] 和 nums2[j .. n-1]
int i = (left + right) / 2;
int j = (m + n + 1) / 2 - i;
//nums_im1, nums_i, nums_jm1, nums_j 分别表示 nums1[i-1], nums1[i], nums2[j-1], nums2[j]
//当一个数组不出现在前一部分时,对应的值为负无穷,就不会对前一部分的最大值产生影响
int nums_im1 = (i == 0 ? INT_MIN : nums1[i - 1]);
//当一个数组不出现在后一部分时,对应的值为正无穷,就不会对后一部分的最小值产生影响
int nums_i = (i == m ? INT_MAX : nums1[i]);
int nums_jm1 = (j == 0 ? INT_MIN : nums2[j - 1]);
int nums_j = (j == n ? INT_MAX : nums2[j]);
if (nums_im1 <= nums_j) {
maxVal = max(nums_im1, nums_jm1);
minVal = min(nums_i, nums_j);
left = i + 1;
} else {
right = i - 1;
}
}
return (m + n) % 2 == 0 ? (maxVal + minVal) / 2.0 : maxVal;
}
};
【复杂度分析】
- 时间复杂度: O ( log min ( m , n ) ) ) O(\log\min(m,n))) O(logmin(m,n))),其中 m m m 和 n n n 分别是数组 nums 1 \textit{nums}_1 nums1和 nums 2 \textit{nums}_2 nums2的长度。查找的区间是 [ 0 , m ] [0, m] [0,m],而该区间的长度在每次循环之后都会减少为原来的一半。所以,只需要执行 log m \log m logm 次循环。由于每次循环中的操作次数是常数,所以时间复杂度为 O ( log m ) O(\log m) O(logm)。由于我们可能需要交换 nums 1 \textit{nums}_1 nums1 和 nums 2 \textit{nums}_2 nums2 使得 m ≤ n m \leq n m≤n,因此时间复杂度是 O ( log min ( m , n ) ) ) O(\log\min(m,n))) O(logmin(m,n))) 。
- 空间复杂度: O ( 1 ) O(1) O(1)。