做到这题的时候,开始觉得不是很难,毕竟O(m+n)的遍历算法是比较容易的,题目要的log(m+n)的算法,乍一看也是很容易想到的,只要我们用QuickSort或者说二分法的想法找到这个rank=(m+n)/2的点就可以了。
但是做下去的时候,发现边界条件太复杂了,很难理清楚,特别是奇数和偶数带来的0.5的偏差,让二分做不到真正的二分。
相信看到这篇博客的同学也是刷过题,知道其中的难度的,因此本文就主要简单阐述一下我解题的想法。
注:这个解法在处理边界条件时不是很优雅,但是我觉得能够短时间让自己能够理清楚的算法有时候能够让我们解题更快。
算法的重点思想是这样的(median的值为X),在一般情况下:
- 如果nums1中有k1个点小于X,那么nums2中必须有k2=half-k1个点小于X。
- 如果P是nums1中位置为k1’的点的值,Q是nums2中对应的位置为half-k1’的点的值:
–>1. 如果P等于Q,很完美,两个数组中小于这个值的数字一共half个
–>2. 如果P小于Q,说明k1’对应的值还不够大,真正的中点值应该介于PQ之间,k1在k1’右侧
–>3. 如果P大于Q,说明k1’对应的值太大了,真正的中点值应该介于PQ之间,k1在k1’左侧
首先我们注意到,假设对于median X,nums1和nums2中有k1和k2个点会比X小,那么k1+k2=(m+n)/2,那么我们可以用二分法找到k1和k2。
我们知道二分法的重点在于决定选择下一次迭代在中点左侧还是右侧,那么我们可以用k1+k2=(m+n)/2这个重要的公式来决定。
首先确定k1,我们初始化
double half=(m+n)/2.0;
// 应该有多少个点小于X
int begin=0;
int end=nums1.size()-1;
double mid=(begin+end)/2.0
// mid其实就是每次迭代的k1
double nums_mid=getMid(nums1, mid);
// 这里我们用一个函数得到中点的值,这个值可能是int,也可能是x.5这样的数字
注意我们的mid是一个double的数字,然后我们根据mid的值,可以知道nums1中小于nums_mid的数有几个,这需要根据mid=x.5或mid=x.0决定。我们又知道k1+k2=(m+n)/2,那么k2=half-k1,这样我们求出这次迭代的k2’。
知道了k2’以后,我们可以知道nums2中应当有多少个值小于当前的nums_mid,我们的理想数量是k2’个,但是现实往往没有这么理想,注意:
- k2’<0,说明我们的k1’太大了,nums2中小于nums_mid的数量都已经超过了half。
- k2’>nums2.size()-1,说明我们的k1’太小了,nums2所有的数据小于nums_mid都无法填满。
- 其他情况下,我们需要根据k2’对应的点的数值来决定我们的k1’选取是否合理。
既然知道了到底有多少个数据比nums_mid小,那么我们接下去要决定k1二分法的迭代是向左还是向右了,有这样几种情形:
- k2’<0,显然k2’需要向右移动
- k2’>nums2.size()-1,显然我们的nums_mid太大了,向左二分。
- 正常情况下,我们需要根据k2对应的数值nums_mid的关系来确定向左还是向右:
–>1.如果num_k2大于nums_mid,说明我们的nums_mid不够大,向右。
–>2.如果num_k2小于nums_mid,说明的nums_mid太大了,向左。
用同样的方法我们可以找到k2。
接下来我用了一个不是很优雅但是很清楚的解法找到真正的中值:
有了k1和k2以后,由于k1和k2都有可能是x.0或x.5这样的数字,因此我们不知道到底哪个或者哪两个点才是中值点。这非常难以判断,但是没关系,我们的算法已经去除了绝大部分的数据点。接下去我们可以找到我们的候选中点集合candidate。
与k1和k2相关的点有几个呢?答案是不超过6个,最复杂的情况下,是
k1-1, k1, k1+1;k2-1, k2, k2+1
,如果k1和k2都是x.5的形式,那就只有4个点int(k1), int(k1)+1;int(k2), int(k2)+1
,有些情况下,我们的点可能比4个还少,比如k1和k2比0小或是比数组的总长度更大的时候。这个边界条件是容易确定的,尽管编码比较长,但是相对容易。
找到candidate集合以后,我们还可以知道比这些candidate小的点有多少个(毕竟已知k1,k2),定义为small。
最后我们只需要找到candidate里面第half-small个点就可以了,当然,可能是两个点。
这样做下来我们的时间复杂度是
O(log(m))+O(log(n))+O(6)=O(log(mn))=O(log(m+n))
注:可以证明
1. log(mn)>log(m+n)
2. 2log(m+n)=log(m+n)^2=log(m^2+2mn+n^2)>log(mn)
完整代码如下:
class Solution {
public:
double getMid(vector<int>& nums, int begin, int end)
{
return (begin+end)%2==0?nums[(begin+end)/2]:(nums[(begin+end)/2]+nums[(begin+end)/2+1])/2.0;
}
double getSmallerNum(int begin, int end)
{
return (begin + end) / 2.0;
}
double getIndexedNum(vector<int>& nums, double idx)
{
if (idx > nums.size() - 1)
{
return -1;
}
else if (idx < 0)
{
return -2;
}
if (idx - int(idx) > 0)
{
return (nums[int(idx)] + nums[int(idx)+1])/2.0;
}
else {
return nums[int(idx)];
}
}
double findPosition(vector<int>& nums1, vector<int>& nums2)
{
int total = nums1.size() + nums2.size();
double half = (total-1) / 2.0;
int begin = 0, end = nums1.size()-1;
while(1){
int pre1 = begin, pre2 = end;
double mid = getMid(nums1, begin, end);
double midx = getSmallerNum(begin, end);
double idx2 = half - midx;
double idxedNum = getIndexedNum(nums2, idx2);
if (idxedNum == -1 || idxedNum > mid)
{
begin = (begin + end) / 2;
}
else if (idxedNum == -2 || idxedNum < mid)
{
end = (begin + end) / 2;
}
else {
return midx;
}
if (begin == pre1 && end == pre2)
{
return (begin + end) / 2.0;
}
}
}
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
if (nums1.size() == 0)
{
return getMid(nums2, 0, nums2.size()-1);
}
if (nums2.size() == 0)
{
return getMid(nums1, 0, nums1.size()-1);
}
double idx1 = findPosition(nums1, nums2);
double idx2 = findPosition(nums2, nums1);
int small = int(idx1) + int(idx2);
int candidate[6];
for (int i = 0; i < 6; ++i)
{
candidate[i] = 1e8;
}
candidate[0] = nums1[int(idx1)];
if (idx1 - int(idx1) > 0)
{
candidate[1] = nums1[int(idx1)+1];
}
else{
int intidx = int(idx1);
if (intidx - 1 >= 0)
{
candidate[1] = nums1[int(idx1) - 1];
small -= 1;
}
if (intidx + 1 < nums1.size())
{
candidate[2] = nums1[int(idx1) + 1];
}
}
candidate[3] = nums2[int(idx2)];
if (idx2 - int(idx2) > 0)
{
candidate[4] = nums2[int(idx2)+1];
}
else{
int intidx = int(idx2);
if (intidx - 1 >= 0)
{
candidate[4] = nums2[int(idx2) - 1];
small -= 1;
}
if (intidx + 1 < nums2.size())
{
candidate[5] = nums2[int(idx2) + 1];
}
}
sort(candidate, candidate+6);
double half = (nums1.size() + nums2.size()-1) / 2.0;
int temp = int(half) - small;
if (half - int(half) > 0)
{
return (candidate[temp] + candidate[temp+1])/2.0;
}
else {
return candidate[temp];
}
}
};