一、概述
给出两个有序数组,找它们的中位数。
时间复杂度要求O(log(m+n))。
二、分析
1、我的方法(没写代码,边界条件太多)
一看时间复杂度要求log,肯定是用二分。那么如何二分呢?想到正常二分法是一次劈一半长度,因此照着平常的思路去做,将长度减半。
维护四个变量,start1,start2,end1,end2。分别表示目前两个数组还剩的部分。
首先找到两个数组各自的中位数mid1和mid2。如果是偶数,要前面那个。这样一来,两个数组分别被分为三部分:
L1 < mid1 < R1
L2 < mid2 < R2
如果mid1>mid2,那么由上图可知,L2<mid2<mid1<R1。
注意到L2<mid2<R2,那么就有R1、R2、mid1、mid2都比L2大,因此中位数不可能在L2中。
这个可以举例,更容易理解。
1 3 5 7 9 L1:1 3 mid1:5 R1:7 9
2 4 6 8 10 L2:2 4 mid2:6 R2:8 10
有mid1<mid2,也就是说,对于L2,有5,6,7 9,8 10都比它们大,有六个数比他们大,中位数不可能出在里面,删去。
把L2删掉,即令start2=mid2。
同样的,对于R1,有L1、L2、mid1、mid2小于R1,中位数不可能在R1中,把R1删掉,即令end1=mid1。
这样一来,删掉L2,R1,相当于总长度减了半。
对于mid1<mid2的情况也是一样。
一直到有一个数组长度为2或者1的时候。就不能再这么做了。
如果继续这样做,可能会出现减多了的情况,如下例:
4 5 6 7 8 9 10 11
1 2 3
mid1=7,mid2=2,删掉1和8,9,10,11
4 5 6 7
2 3
mid1=5,mid2=2,删掉6,7
然而结果应该是6
这时就删多了。
因此当长度为1或2,就应退出循环。
也就是说,我这种算法,最多能把问题退化成一个数组和一个或两个数组合,求中位数。还是麻烦,情况很多。
2、较好的方法
“找第k个数”,也就是“找k-1个数,它们比第k个数小”,也就是“找最小的k-1个数”。
首先明确一点,第k个数,指的是从第1个开始数,不是第0个。
比如说:
1 3 5 7 9
2 4 6 8
第k个指的是第5个数,也就是5。
思路如下:
找中位数→找第len/2的数→不好找→退而求其次→找第几个数好找呢?→找第一个数→就是两个数组的第一个中的较小那个
→只有一个数组→直接找另一个数组的第k个数
因此只要最后将“找第k个数”退化成“找第1个数”或者“令一个数组为空”即可。
二分法不再二分长度,而是二分k。
还有一个好处,就是不必考虑奇数长度或者偶数长度的复杂性了。奇数找第k个,偶数找第k个和k+1个即可。十分容易理解。
其实思路与我的思路有相似之处:
找第k个,选择两个数组中第k/2位置的数来比较,k/2是向下取整,先不管为什么分别找第k/2的数,先考虑找到了之后怎么处理:
五种情况:
1、数组a的大于数组b的。
同样以L1 mid1 R1 L2 mid2 R2来处理。这里则有mid1>mid2,因此L2和mid2可以删掉。也即是“最小的k-1个数”中,我们已经找到了L2+1个。然后找k-L2-1个就好了。
2、数组a的小于数组b的。
同样以L1 mid1 R1 L2 mid2 R2来处理。这里则有mid1<mid2,因此L1和mid1可以删掉。也即是“最小的k-1个数”中,我们已经找到了L1+1个。然后找k-L1-1个就好了。
3、数组a的等于数组b的。
同样以L1 mid1 R1 L2 mid2 R2来处理。这里则有mid1=mid2,因此L1、L2可以删掉。也即是“最小的k-1个数”中,我们已经找到了L1+L2个。然后找k-L1-L2个就好了。注意这里L1和L2是等于k/2的,因此中位数就直接是mid1或者mid2了。
4、数组a<k/2
即数组a的长度不够找k/2个了,那么说明中位数一定不在L2中,把L2和mid2删除即可。
5、数组b<k/2
即数组b的长度不够找k/2个了,那么说明中位数一定不在L1中,把L1和mid1删除即可。
举以下几个例子说明:
例一:
2 4 6 8
1 3 5 7 9
k=5,找第5个数。
k/2=2,为简便,令数组a的mid1为第2个,mid2为第3个。
4<5,于是删掉2 4,找第5个变成找第3个。
6 8
1 3 5 7 9
k/2=1,为简便,mid1为第1个,mid2为第2个。
6>3,于是删掉1 3,找第3个变成找第1个。
6 8
5 7 9
找较小的,也就是5,找到了。
例二:
1 3 5 7 9
2 4 6 8 10
k=5,找第k+1也就是第6个。
k/2=3,mid1=5,mid2=6,删掉1 3 5,找第3个。
7 9
2 4 6 8 10
k/2=1,mid1=7,mid2=4,删掉2 4,找第1个。
7 9
6 8 10
找较小的,找6,找到了。
例三:
0 1 2 3
4
k=3,找第k个也就是第3个。
k/2=1,找数组a的第1个,数组b的第2个,数组b没有第2个,因此直接把数组a的第1个删除。
1 2 3
4
k=2,k/2=1,mid1=1,mid2=4,删除mid1,找第1个。
2 3
4
找较小的,也就是2,找到了。
然后解释为什么是二分k:
从上面我们可以知道,整个过程就是往一个k-1大的桶里扔数,扔满了就完事。那么如何让每次扔的期望最多呢?每次扔k/2个,如果选择k/3,那么好的时候一次扔k*2/3个,坏的时候就得扔k/3个,效率不够平均,不好,因此选择k/2。
由此可以写出代码:
采用递归,函数如下:
double findKth(vector<int>& nums1, vector<int>& nums2,int start1,int start2,int len1,int len2,int k)
{
if(len1>len2)
return findKth(nums2,nums1,start2,start1,len2,len1,k);
if(len1==0)
return nums2[start2+k-1];
if(k==1)
{
return min(nums1[start1],nums2[start2]);
}
int p=min(k/2,len1);
int q=k-p;
if(nums1[start1+p-1]>nums2[start2+q-1])
return findKth(nums1,nums2,start1,start2+q,len1,len2-q,k-q);
else if(nums1[start1+p-1]<nums2[start2+q-1])
return findKth(nums1,nums2,start1+p,start2,len1-p,len2,k-p);
else
return nums1[start1+p-1];
}
参数有以下几个:数组a,数组b,a的起始位置(也就是“删除”操作执行后的效果,毕竟不能真删了),b的起始位置,a的产股,b的长度,k值。
为了简化情况,我们一直令较短的为数组a,这样一来,对a找k/2就很直观,不必考虑数组b的长度是否不满足k/2了,也不必考虑数组2的长度是否为0了。
因此首先判断长度,若数组a的长度大于数组b,那么就全掉换。
然后是第一个注意点:
先判断数组a的长度是否为0,为0则直接返回数组b的第k个元素;若长度不为0,再判断k是否等于1。
若是反过来,是这样:先判断k是否等于1,等于1则返回两个数组头中较小的;然后判断是否有数组长度为0。
注意这样一种情况:
我们看这里的start1,其含义是“数组a的起始位置”,是这样得到的,用上一次递归的起始位置加上k/2,也就是“删去L1和mid1”,那么如果存在这样一种情况,数组a只有一个元素,k等于2,而且删去了mid1,那么数组a就不存在元素了,接下来的下一次递归,start1就会溢出,这时进入判断“k等于1”,判断成功,返回值有可能会溢出。见下例:
1 2
3 4
k=3,k/2=1,(k=3是因为偶数个要找第2个和第3个)
mid1=1,mid2=4,删去1;
2
3 4
k=2,k/2=1,
mid1=2,mid2=3,删去2;
-
3 4
k=1,然而数组a不存在元素,出错。
反过来判断则可以防止这种情况。
也就是说:先判断长度,再查看元素,这一点要记住,不光本题,许多题目都要注意这一点。
第二个注意点:p应该是k/2和len1中较小的那个。
同样是为了防止溢出。
然后递归即可。
主函数如下:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int m=nums1.size(),n=nums2.size();
int k=(m+n)/2;
int len=m+n;
if(len%2==0)
return (findKth(nums1,nums2,0,0,m,n,k)+findKth(nums1,nums2,0,0,m,n,k+1))/2;
else
return findKth(nums1,nums2,0,0,m,n,k+1);
}
省去了让人头大的奇数偶数,看一个还是看两个的问题,十分让人舒适。
三、总结
很巧妙的一道题,算法思想应理解。
巧妙之处一大一小:
大的在于二分k而不是二分长度,小的在于一直令短的为数组a简化情况。
PS:代码如下:
class Solution {
private:
double findKth(vector<int>& nums1, vector<int>& nums2,int start1,int start2,int len1,int len2,int k)
{
if(len1>len2)
return findKth(nums2,nums1,start2,start1,len2,len1,k);
if(len1==0)
return nums2[start2+k-1];
if(k==1)
{
return min(nums1[start1],nums2[start2]);
}
int p=min(k/2,len1);
int q=k-p;
if(nums1[start1+p-1]>nums2[start2+q-1])
return findKth(nums1,nums2,start1,start2+q,len1,len2-q,k-q);
else if(nums1[start1+p-1]<nums2[start2+q-1])
return findKth(nums1,nums2,start1+p,start2,len1-p,len2,k-p);
else
return nums1[start1+p-1];
}
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int m=nums1.size(),n=nums2.size();
int k=(m+n)/2;
int len=m+n;
if(len%2==0)
return (findKth(nums1,nums2,0,0,m,n,k)+findKth(nums1,nums2,0,0,m,n,k+1))/2;
else
return findKth(nums1,nums2,0,0,m,n,k+1);
}
};