题目来自LeetCode,链接:寻找两个有序数组的中位数。给定两个大小为 m 和 n 的有序数组 nums1 和 nums2。请你找出这两个有序数组的中位数,并且要求算法的时间复杂度为 O(log(m + n))。你可以假设 nums1 和 nums2 不会同时为空。
示例1:
nums1 = [1, 3]
nums2 = [2]
则中位数是 2.0
示例2:
nums1 = [1, 2]
nums2 = [3, 4]
则中位数是 (2 + 3)/2 = 2.5
首先我们先不管时间复杂度,用最朴素的方法就是直接找到两个数组中位于中间的数,这很容易做到,就是利用两个索引分别指向两个数组,每次都增大数值较小的那个数组的索引,直到找到中位数。这种方法的时间复杂度为 O ( m + n ) O(m+n) O(m+n),空间复杂度为 O ( 1 ) O(1) O(1)。
JAVA版代码如下:
class Solution {
public double findMedianSortedArrays(int[] nums1, int[] nums2) {
int m = nums1.length;
int n = nums2.length;
if (m == 0) {
return n % 2 == 0 ? (nums2[n/2] + nums2[n/2 - 1]) / 2.0 : nums2[n/2];
}
if (n == 0) {
return m % 2 == 0 ? (nums1[m/2] + nums1[m/2 - 1]) / 2.0 : nums1[m/2];
}
int idx1 = 0, idx2 = 0;
int count = 0;
int middle = 0;
while (count * 2 < m + n) {
if (idx2 >= n || (idx1 < m && nums1[idx1] < nums2[idx2])) {
middle = nums1[idx1++];
}
else {
middle = nums2[idx2++];
}
++count;
}
if ((m + n) % 2 == 1) {
return middle * 1.0;
}
else {
int temp = (idx2 >= n || (idx1 < m && nums1[idx1] < nums2[idx2])) ? nums1[idx1] : nums2[idx2];
return (middle + temp) / 2.0;
}
}
}
提交结果如下:
接着再看怎么进一步减少时间复杂度,其实一看到要把复杂度降到log,就应该知道跟二分法有关了。先讲一种比较好懂的方法,是从评论区看到的。令k=(m+n+1)/2
,那么找中位数其实就等价于找第k小的数(m+n为奇数的情况)或第k小和第k+1小的两数的平均数(m+n为偶数的情况)。具体的做法就是选择nums1第k/2
个元素(记为n1
)和nums2的第k/2
个元素(记为n2
)相比,如果有n1<n2
,就说明nums1前k/2
个元素一定小于中位数,所以可以把这部分从nums1中删去(当然不用真的删除,用一个索引说明就行),然后就变为在删除了部分数的nums1和不变的nums2中寻找第k-k/2
小的数了。如果是n2<=n1
的情况,就反过来就好了。因为每次k
的值都会减少一半,所以时间复杂度就是
O
(
l
o
g
k
)
O(logk)
O(logk),也就是
O
(
l
o
g
(
m
+
n
)
)
O(log(m+n))
O(log(m+n)),空间复杂度则是
O
(
1
)
O(1)
O(1)。
JAVA版代码如下:
class Solution {
public double findMedianSortedArrays(int[] nums1, int[] nums2) {
int m = nums1.length, n = nums2.length;
if (m == 0) {
return n % 2 == 0 ? (nums2[n/2] + nums2[n/2 - 1]) / 2.0 : nums2[n/2];
}
if (n == 0) {
return m % 2 == 0 ? (nums1[m/2] + nums1[m/2 - 1]) / 2.0 : nums1[m/2];
}
int k = (m + n + 1) / 2;
int idx1 = 0, idx2 = 0;
int left1 = -1, left2 = -1;
// 去掉k-1个数后idx1和idx2指向的两个数之中较小者就是第k小的数了
while (k > 1) {
idx1 = Math.min(m - 1, k / 2 + left1);
idx2 = Math.min(n - 1, k - k / 2 + left2);
if (nums1[idx1] < nums2[idx2]) {
// nums1在idx1前面的所有数一定都比中位数小,全部去掉
k -= idx1 - left1;
if (idx1 == m - 1) {
return (m + n) % 2 == 0 ? (nums2[left2 + k] + nums2[left2 + k + 1]) / 2.0 : nums2[left2 + k];
}
left1 = idx1;
if (k <= 1) {
idx1 = left1 + 1;
}
}
else {
// nums2在idx2前面的所有数一定都比中位数小,全部去掉
k -= idx2 - left2;
if (idx2 == n - 1) {
return (m + n) % 2 == 0 ? (nums1[left1 + k] + nums1[left1 + k + 1]) / 2.0 : nums1[left1 + k];
}
left2 = idx2;
if (k <= 1) {
idx2 = left2 + 1;
}
}
}
if ((m + n) % 2 == 0) {
// 偶数的情况下需要找到中间的两个数
int n1, n2;
if (nums1[idx1] < nums2[idx2]) {
n1 = nums1[idx1];
if (idx1 == m - 1) {
n2 = nums2[idx2];
}
else {
n2 = nums1[idx1 + 1] < nums2[idx2] ? nums1[idx1 + 1] : nums2[idx2];
}
}
else {
n1 = nums2[idx2];
if (idx2 == n - 1) {
n2 = nums1[idx1];
}
else {
n2 = nums2[idx2 + 1] < nums1[idx1] ? nums2[idx2 + 1] : nums1[idx1];
}
}
return (n1 + n2) / 2.0;
}
else {
return nums1[idx1] < nums2[idx2] ? nums1[idx1] * 1.0 : nums2[idx2] * 1.0;
}
}
}
提交结果如下:
最后还是从评论区看到的另一种更优但也更复杂的做法。有点类似前面的做法,还是取k=(m+n+1)/2
,然后不失一般性假设nums1
的长度不超过nums2
的(不然就将两个数组换一下),初始化iMin=0
和iMax=m
,然后指定一个索引i=iMin+(iMax-iMin)/2
(注意这里就是二分法)表示将nums1
的前i
个数划分为左半部分,同理指定一个索引j=k-i
表示将nums2
的前j
个数划分为左半部分,这样总的左半部分就有k
个数,当满足nums1[i-1]<nums2[j]
且nums2[j-1]<nums1[i]
的时候就说明左半部分最大的数一定小于右半部分最小的数,从而我们可以知道第k小的数就是max(nums1[i-1], nums2[j-1])
,第k+1
小的数就是min(nums1[i], nums2[j])
(这都是在i,j
未到边界的情况下得到的,到边界的话需要特殊处理一下)。但如果有nums1[i-1]>nums2[j]
说明i
的位置还需要往左移动以满足nums1[i-1]<nums2[j]
(注意i
左移的同时j
会随之右移),这里咋移动呢,就是iMax=i-1
(注意这里相当于二分了,之后i
的位置其实就是原先左半部分的中间位置),同理nums2[j-1]>nums1[i]
的时候需要将i
右移(iMin=i+1
)。此外需要注意的就是边界条件也就是i=0/m
和j=0/n
时候的特殊处理,细心点就好了。因为我们是在长度较小的那个数组上进行二分,所以时间复杂度就是
O
(
l
o
g
(
m
i
n
(
m
,
n
)
)
)
O(log(min(m,n)))
O(log(min(m,n))),空间复杂度还是
O
(
1
)
O(1)
O(1)。
JAVA版代码如下:
class Solution {
public double findMedianSortedArrays(int[] nums1, int[] nums2) {
int m = nums1.length, n = nums2.length;
if (m > n) {
return findMedianSortedArrays(nums2, nums1);
}
if (m == 0) {
return n % 2 == 0 ? (nums2[n/2 - 1] + nums2[n/2]) / 2.0 : nums2[n/2];
}
int middle = (m + n + 1) / 2;
int iMin = 0, iMax = m;
while (iMin <= iMax) {
int i = iMin + (iMax - iMin) / 2;
int j = middle - i;
if (i < m && nums1[i] < nums2[j - 1]) {
// nums1[i]太小了,需要增大i
iMin = i + 1;
}
else if (i > 0 && nums1[i - 1] > nums2[j]) {
// nums1[i - 1]太大了,需要减小i
iMax = i - 1;
}
else {
// 找到合适的i,j或者i,j到了边界
int middleNum, middleNum_1;
if (i == 0) {
middleNum = nums2[j - 1];
}
else if (j == 0) {
middleNum = nums1[i - 1];
}
else {
middleNum = Math.max(nums1[i - 1], nums2[j - 1]);
}
if ((m + n) % 2 == 1) {
return middleNum;
}
if (i == m) {
middleNum_1 = nums2[j];
}
else if (j == n) {
middleNum_1 = nums1[i];
}
else {
middleNum_1 = Math.min(nums1[i], nums2[j]);
}
return (middleNum + middleNum_1) / 2.0;
}
}
return 0.0;
}
}
提交结果如下:
Python版代码如下:
class Solution:
def findMedianSortedArrays(self, nums1: List[int], nums2: List[int]) -> float:
m = len(nums1)
n = len(nums2)
if m > n:
return self.findMedianSortedArrays(nums2, nums1)
if m == 0:
return (nums2[n//2 - 1] + nums2[n//2]) / 2 if n % 2 == 0 else nums2[n//2]
iMin, iMax = 0, m
middle = (m + n + 1) // 2
while iMin <= iMax:
i = iMin + (iMax - iMin) // 2
j = middle - i
if i > 0 and nums1[i - 1] > nums2[j]:
iMax = i - 1
elif i < m and nums1[i] < nums2[j - 1]:
iMin = i + 1
else:
middleNum, middleNum_1 = 0, 0
if i == 0:
middleNum = nums2[j - 1]
elif j == 0:
middleNum = nums1[i - 1]
else:
middleNum = max(nums1[i - 1], nums2[j - 1])
if (m + n) % 2 == 1:
return middleNum
if i == m:
middleNum_1 = nums2[j]
elif j == n:
middleNum_1 = nums1[i]
else:
middleNum_1 = min(nums1[i], nums2[j])
return (middleNum + middleNum_1) / 2
return 0.0
提交结果如下: