题目描述
给定两个大小为 m
和n
的有序数组 nums1
和 nums2
。
请你找出这两个有序数组的中位数,并且要求算法的时间复杂度为
O
(
l
o
g
(
m
+
n
)
)
O(log(m + n))
O(log(m+n))。
你可以假设 nums1
和nums2
不会同时为空。
示例 1:
nums1 = [1, 3]
nums2 = [2]
则中位数是 2.0
示例 2:
nums1 = [1, 2]
nums2 = [3, 4]
则中位数是 (2 + 3)/2 = 2.5
解题思考
其实我一开始想的是利用桶排来做这道题。
简单来说就是建立一个数组int p[MAX] = {0}
,将nums1
和nums2
两个数组中的数据作为下标,每次下标x
出现,那么p[x]++
,并直接对大于0
的元素进行求和,直到累加值
c
o
u
n
t
=
(
m
+
n
)
/
2
+
1
count = (m+n)/2+1
count=(m+n)/2+1(奇数)或者
{
c
o
u
n
t
=
(
m
+
n
)
/
2
,
c
o
u
n
t
=
(
m
+
n
)
/
2
+
1
}
\{count=(m+n)/2,\ count=(m+n)/2+1\}
{count=(m+n)/2, count=(m+n)/2+1}(偶数)。
这样,时间复杂度就为 O ( ( m + n ) / 2 ) O((m+n)/2) O((m+n)/2)。看起来上面的解法很不错,达到了线性的时间,但是题目要求进一步压缩时间复杂度,降至 O ( l o g ( m + n ) ) O(log(m + n)) O(log(m+n))这下问题就有意思多了。什么算法能够产生 l o g log log的时间复杂度?我的第一反应是:分治策略。
当然,分治法还有一些逻辑细节需要理清楚。下面我来介绍一下思路:
- 首先,我们知道
nums1
和nums2
两个数组都是单调递增的,那么:
i f n u m s 1 [ x ] ≤ n u m s 2 [ y ] t h e n n u m s 1 [ 0 ] < . . . < n u m s [ x ] ≤ n u m s 2 [ y ] if \ \ nums1[x] \le nums2[y] \\ then \ \ nums1[0]<...<nums[x] \le nums2[y] if nums1[x]≤nums2[y]then nums1[0]<...<nums[x]≤nums2[y] - 已知我们的中位数下标
k
=
(
m
+
n
)
/
2
k = (m+n)/2
k=(m+n)/2,则:
i ⇐ ( k − 1 ) / 2 i f n u m s 1 [ i ] < n u m s 2 [ i ] t h e n e r a s e ( n u m s 1 [ 0 ] , . . . , n u m s 1 [ i ] ) e l s e i f n u m s 2 [ i ] < n u m s 1 [ i ] t h e n e r a s e ( n u m s 2 [ 0 ] , . . . , n u m s 2 [ i ] ) k = k − i i \Leftarrow (k - 1)/2\\ if \ \ nums1[i]<nums2[i] \\ then \ erase(nums1[0] ,\ ... \ , nums1[i]) \\ else \ if \ nums2[i]<nums1[i] \\ then \ erase(nums2[0] ,\ ... \ , nums2[i]) \\ k = k- i i⇐(k−1)/2if nums1[i]<nums2[i]then erase(nums1[0], ... ,nums1[i])else if nums2[i]<nums1[i]then erase(nums2[0], ... ,nums2[i])k=k−i - 当然,上述公式中没有考虑编程语言中的整型,浮点型的区别,也没有考虑
m
+
n
m+n
m+n奇数和偶数的区别。因此,仅仅是个思路而已。我解释一下其中的含义:
- 因为我们要寻找第
k
k
k个数,那么我们可以先找到两个数组中的第
k
/
2
k/2
k/2的数,比较他们的大小。如果
n
u
m
s
1
[
k
/
2
]
<
n
u
m
s
2
[
k
/
2
]
nums1[k/2]<nums2[k/2]
nums1[k/2]<nums2[k/2]由于单调性,我们可以确定
nums1[k/2]
及其之前的所有数,都一定小于中位数。那么这些数对我来说就没有用了,因此我将其从数组中移除。此时,我们只需要找到新的两个串第 k − k / 2 k-k/2 k−k/2小的数,就可以了。 - 接下来就是迭代或者递归了,到最后, k = 1 k = 1 k=1,意味着我们只需要取到第1小的数,这个数就是中位数。
- 当然,在真实做题的时候,我们需要考虑奇偶。
- 因为我们要寻找第
k
k
k个数,那么我们可以先找到两个数组中的第
k
/
2
k/2
k/2的数,比较他们的大小。如果
n
u
m
s
1
[
k
/
2
]
<
n
u
m
s
2
[
k
/
2
]
nums1[k/2]<nums2[k/2]
nums1[k/2]<nums2[k/2]由于单调性,我们可以确定
接下来是解题步骤:
- 首先,中位数是处在一串数字正中间的数。如果数字串长度和为奇数,那么 k = ( m + n ) / 2 + 1 k=(m+n)/2+1 k=(m+n)/2+1;如果字符串长度为偶数,则是 { k = ( m + n ) / 2 , k = ( m + n ) / 2 + 1 } \{k=(m+n)/2,\ k=(m+n)/2+1\} {k=(m+n)/2, k=(m+n)/2+1},接下来分类讨论,有下面的几种情况,我就不细说了。
- 首先是处理两个数组:
i f 有 单 个 串 为 空 : t h e n : i f 奇 数 : r e t u r n n u m s [ ( m + n ) / 2 ] e l s e i f 偶 数 : r e t u r n ( n u m s [ ( m + n ) / 2 − 1 ] + n u m s [ ( m + n ) / 2 ] ) / 2 i f 两 个 串 不 为 空 : t h e n : w h i l e k ! = 1 : i ⇐ ( k − 1 ) / 2 − 1 i f n u m s 1. s i z e < i : i f n u m s 1. b a c k < n u m s 2 [ i ] : n u m s 1. c l e a r b r e a k e l s e : n u m s 2. e r a s e ( n u m s 2. b e g i n , n u m s 2. b e g i n + i + 1 ) ; k = k − i − 1 e l s e i f n u m s 2. s i z e < i : i f n u m s 2. b a c k < n u m s 1 [ i ] : n u m s 2. c l e a r b r e a k e l s e : n u m s 1. e r a s e ( n u m s 1. b e g i n , n u m s 1. b e g i n + i + 1 ) ; k = k − i − 1 e l s e : i f ( n u m s 1 [ i ] ≤ n u m s 2 [ i ] ) : n u m s 1. e r a s e ( n u m s 1. b e g i n , n u m s 1. b e g i n + i + 1 ) e l s e : n u m s 2. e r a s e ( n u m s 2. b e g i n , n u m s 2. b e g i n + i + 1 ) k = k − i − 1 if \ 有单个串为空:\\ then: \\ \ \ \ \ \ \ \ \ if \ 奇数: return \ nums[(m+n) / 2]\\ \ \ \ \ \ \ \ \ else \ if 偶数:return \ (nums[(m+n) / 2-1]+nums[(m+n) / 2])/2\\ if\ 两个串不为空:\\ then:\\ \ \ \ \ \ \ \ \ while\ k\ !=1:\\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ i \Leftarrow (k-1)/2-1\\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ if\ nums1.size < i:\\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ if\ nums1.back<nums2[i]:\\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ nums1.clear\\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ break\\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ else:\\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ nums2.erase(nums2.begin, nums2.begin + i + 1);\\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ k = k - i - 1\\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ else \ if\ nums2.size < i:\\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ if\ nums2.back<nums1[i]:\\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ nums2.clear\\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ break\\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ else:\\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ nums1.erase(nums1.begin, nums1.begin + i + 1);\\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ k = k - i - 1\\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ else:\\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ if \ (nums1[i] \leq nums2[i]):\\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ nums1.erase(nums1.begin, nums1.begin + i + 1)\\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ else:\\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ nums2.erase(nums2.begin, nums2.begin + i + 1)\\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ k = k - i - 1 if 有单个串为空:then: if 奇数:return nums[(m+n)/2] else if偶数:return (nums[(m+n)/2−1]+nums[(m+n)/2])/2if 两个串不为空:then: while k !=1: i⇐(k−1)/2−1 if nums1.size<i: if nums1.back<nums2[i]: nums1.clear break else: nums2.erase(nums2.begin,nums2.begin+i+1); k=k−i−1 else if nums2.size<i: if nums2.back<nums1[i]: nums2.clear break else: nums1.erase(nums1.begin,nums1.begin+i+1); k=k−i−1 else: if (nums1[i]≤nums2[i]): nums1.erase(nums1.begin,nums1.begin+i+1) else: nums2.erase(nums2.begin,nums2.begin+i+1) k=k−i−1 - 然后是输出。这里需要考虑只有单个串、有两个串的两种情况。因为 k = 1 k=1 k=1,因此单个串只需要输出最小的一个,两个串分奇偶两种情况。奇数时,只需要输出两个串中最小的一个。偶数时,有两种情况:1. 两个最小数字在同一个串 2. 两个最小数字分别在两个串。逻辑不难,就不贴伪代码了。
代码
#include <iostream>
#include <vector>
using namespace std;
class Solution {
public:
double findMedianSortedArrays(vector<int> &nums1, vector<int> &nums2) {
bool is_odd;
if (nums1.empty() && ~nums2.empty()) {
if (nums2.size() % 2 != 0)
return nums2[nums2.size() / 2];
else
return (nums2[nums2.size() / 2] + nums2[nums2.size() / 2 - 1]) / 2.0;
} else if (nums2.empty() && ~nums1.empty()) {
if (nums1.size() % 2 != 0)
return nums1[nums1.size() / 2];
else
return (nums1[nums1.size() / 2] + nums1[nums1.size() / 2 - 1]) / 2.0;
} else {
// 当数组都为非空的时候
if ((nums1.size() + nums2.size()) % 2 != 0) {
k = (nums1.size() + nums2.size()) / 2 + 1;
is_odd = true;
} else {
k = (nums1.size() + nums2.size()) / 2;
is_odd = false;
}
while (k != 1) {
int i = (k - 1) / 2 - 1;
if (i == -1) {
nums1[0] <= nums2[0] ? nums1.erase(nums1.begin()) : nums2.erase(nums2.begin());
k -= 1;
} else if (nums1.size() <= i + 1) {
if (nums1.back() <= nums2[i]) {
k -= nums1.size();
nums1.clear();
break;
} else {
nums2.erase(nums2.begin(), nums2.begin() + i + 1);
k = k - i - 1;
}
} else if (nums2.size() <= i + 1) {
if (nums2.back() <= nums1[i]) {
k -= nums2.size();
nums2.clear();
break;
} else {
nums1.erase(nums1.begin(), nums1.begin() + i + 1);
k = k - i - 1;
}
} else {
if (nums1[i] <= nums2[i]) {
nums1.erase(nums1.begin(), nums1.begin() + i + 1);
} else {
nums2.erase(nums2.begin(), nums2.begin() + i + 1);
}
k = k - i - 1;
}
}
if (nums1.empty() && ~nums2.empty()) {
if (is_odd)
return nums2[k - 1];
else
return (nums2[k - 1] + nums2[k]) / 2.0;
} else if (nums2.empty() && ~nums1.empty()) {
if (is_odd)
return nums1[k - 1];
else
return (nums1[k - 1] + nums1[k]) / 2.0;
} else {
if (is_odd)
return nums1[0] < nums2[0] ? nums1[0] : nums2[0];
else {
if (nums1.size() > nums2.size()) {
if (nums2.size() > 1) {
if (nums2[1] < nums1[0])
return (nums2[0] + nums2[1]) / 2.0;
else if (nums1[1] < nums2[0])
return (nums1[0] + nums1[1]) / 2.0;
else
return (nums1[0] + nums2[0]) / 2.0;
} else {
if (nums1[1] < nums2[0])
return (nums1[0] + nums1[1]) / 2.0;
else
return (nums1[0] + nums2[0]) / 2.0;
}
} else if (nums1.size() < nums2.size()) {
if (nums1.size() > 1) {
if (nums1[1] < nums2[0])
return (nums1[0] + nums1[1]) / 2.0;
else if (nums2[1] < nums1[0])
return (nums2[0] + nums2[1]) / 2.0;
else
return (nums1[0] + nums2[0]) / 2.0;
} else {
if (nums2[1] < nums1[0])
return (nums2[0] + nums2[1]) / 2.0;
else
return (nums1[0] + nums2[0]) / 2.0;
}
} else {
if (nums1.size() > 1) {
if (nums1[1] <= nums2[0])
return (nums1[0] + nums1[1]) / 2.0;
else if (nums2[1] <= nums1[0])
return (nums2[0] + nums2[1]) / 2.0;
else
return (nums1[0] + nums2[0]) / 2.0;
} else
return (nums1[0] + nums2[0]) / 2.0;
}
}
}
}
}
private:
int k = 0;
};
int main() {
vector<int> b = {5, 6, 8};
vector<int> a = {1, 2, 3, 4, 7, 9, 10};
double result;
Solution solution;
result = solution.findMedianSortedArrays(a, b);
cout << result;
}
总结
这个代码写得比较垃圾,条件分支太多,自己看着也不爽。但是因为是晚上写的,没空简化了。等有空了重写一遍。