1.题目描述
2.选题原因
本周为两节习题课,在习题课的最后,讨论到了这一题,没有讨论出结果。本来有一些思路,恰好看到了这一题,于是将其解决。
3.题目分析及算法
3.1分析
这道题本来解法不难,正常思路有两种,一种是直接讲两个数组排序,然后直接找到中间数值,其算法复杂度为
O(m + n)
;另一种是从两个数组的首部开始找,每找到一个小数,指针就向下移动一格,直到移动了(m + n) / 2
其复杂度也为O(m + n)
。因此两种算法都不符合题目O(log(m + n))
的要求。
我们看题目要求,很容易联想到,log
的复杂度,说明要用到二分
之类的方法。将本题类比做寻找k
值,(k = (m + n) / 2
)。认真思考,二分的实质即是:每次剔除一般的值,一次递归,一直剔除到需要的值出现。如果是在一个数组中寻找,很简单,每次将两份数值中的一份删去即可,但是如果是两个数组,那又如何实现呢?
想象将两个数组同时二分,分成4段,能否得到可定可以删去的一段呢?是可以的,举个简单的例子:
我们来比较
x
与y
的大小,如果x > y
那么就说明mid n
左边的一段一定是可以删去的,若是x < y
,那么结果相反,删去的即是mid m
左边的一段。一次操作过后,剩下的即是在剩下的三段中间找寻k - mid n/mid m
大的值。
那么新的问题又出现了:如果在递归的过程中,某一段的mid m/mid n > k
了应该怎么办?其实也很好解决,在我们二分的时候,将步长设置为k / 2
即可。
现在考虑,完整的算法应该如何划分,选择。
这样,每次我们就能够从中间删除
k / 2
长度的区间。
但是我们的问题又来了,如果数组不够k / 2
的长度呢?不就会越界访问了吗?因此,区间设置应该为min(k / 2, m_length, n_length)
。
接下来,我们要思考,终点在哪里。实际上,终点就是
k == 1
或者m\n长度为0
的时候。
3.2算法
(m + n)
是奇数,寻找(m + n) / 2
;否则寻找((m + n) / 2 + (m + n) / 2 + 1) / 2
。
寻找k
位算法:
- 若有数组为空,则结果为非空数组的第
k
位- 若
k == 1
,则结果为两个数组首位中较小的那一个- 选定坐标
index1 = min(k / 2, n_length, m_length)
与index2 = min(k / 2, n_length, m_length)
- 比较两个坐标的值,舍弃小的值所在数组坐标左边所有值。
- 结果再次递归,寻找第
k - 舍弃数字个数
的值
4.核心代码
4.1出口
//如果有数组为空,在剩下的数组中寻找即可
if (mlen == 0) {
return n[nstart + k - 1];
} else if (nlen == 0) {
return m[mstart + k - 1];
}
//只找一位,避免死循环
if(k == 1) {
return min(m[mstart], n[nstart]);
}
4.2比较过程
int mmid, nmid;
//避免越界,又要保证删去尽可能多的数据,因此选择长的数组作为起始分界点
if (nlen > mlen) {
mmid = min(k / 2, mlen);
nmid = k - mmid;
} else {
nmid = min(k / 2, nlen);
mmid = k - nmid;
}
//判断
if(m[mstart + mmid - 1] < n[nstart + nmid - 1]) {
return findKth(m, n, mstart + mmid, nstart, mlen - mmid, nlen, k - mmid);
} else {
if(m[mstart + mmid - 1] > n[nstart + nmid - 1]) {
return findKth(m, n, mstart, nstart + nmid, mlen, nlen - nmid, k - nmid);
} else {
return m[mstart + mmid - 1];
}
}
5.结果
6.源代码
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int m = nums1.size();
int n = nums2.size();
int len = m + n;
//判断中位数有一位还是两位
if (len % 2 != 0)
return findKth(nums1, nums2, 0, 0, m, n, len / 2 + 1);
else
{
return (findKth(nums1, nums2, 0, 0, m, n, len / 2) + findKth(nums1, nums2, 0, 0, m, n, len / 2 + 1)) / 2;
}
}
double findKth(vector<int> m, vector<int> n, int mstart, int nstart, int mlen, int nlen, int k)
{
//如果有数组为空,在剩下的数组中寻找即可
if (mlen == 0) {
return n[nstart + k - 1];
} else if (nlen == 0) {
return m[mstart + k - 1];
}
//只找一位,避免死循环
if(k == 1) {
return min(m[mstart], n[nstart]);
}
int mmid, nmid;
//避免越界,又要保证删去尽可能多的数据,因此选择长的数组作为起始分界点
if (nlen > mlen) {
mmid = min(k / 2, mlen);
nmid = k - mmid;
} else {
nmid = min(k / 2, nlen);
mmid = k - nmid;
}
//判断
if(m[mstart + mmid - 1] < n[nstart + nmid - 1]) {
return findKth(m, n, mstart + mmid, nstart, mlen - mmid, nlen, k - mmid);
} else {
if(m[mstart + mmid - 1] > n[nstart + nmid - 1]) {
return findKth(m, n, mstart, nstart + nmid, mlen, nlen - nmid, k - nmid);
} else {
return m[mstart + mmid - 1];
}
}
}