两个有序数组的中位数问题
———leetcode 第4题
来自leetcode上的答案的整理。
Description:
Median of Two Sorted Arrays
There are two sorted arrays nums1 and nums2 of size m and n respectively.
Find the median of the two sorted arrays. The overall run time complexity should be O(log (m+n)).
You may assume nums1 and nums2 cannot be both empty.
描述:有两个非降有序数组,大小分别是m,n,找这两个数组的中位数。假设这两个数组不能全为空。
a) 暴力破解方式
开辟一个新数组,大小为m+n,用归并的方式将两个数组重新排序的一个数组内,然后找中位数,代码比较容易。
b) 不断减少问题规模得到解
问题抽象:
对于两个有序数组A,B,要找集合{A, B}中kth的数,这里找的是中位数( k=(m+n+1)/2 or k=( (m+n+1)/2+(m+n+2)/2 )/2 ).
A的长度为m,B的长度为n,确保m<=n(如果不满足),首先我们分别找这两个数组的k/2位置的元素,我们将会遇到两种情况:
-
A[k/2]>B[k/2],因为是有序的,那么中位数不会位于B[k/2]之前的部分,变成了一个问题规模更小的子问题。
B[0], B[1], …, B[k/2] | B[k/2+1], B[k/2+2], …, B[n-1]
这时候A没有约减,考虑到我们去掉B[0]~B[k/2],去掉了k/2个最小的数了,我们要在新的数组内找(k-k/2)大的数。
然后不停递归比较约减下去。
-
A[k/2]<=B[k/2],因为是有序的,那么中位数不会位于A[k/2]之前的部分,问题规模得以简化。
A[0], A[1], …, A[k/2] | A[k/2+1], A[k/2+2], …, A[m-1]
这时候B没有约减,同样也是去掉了k/2个最小的数了,我们要在新的数组内找(k=k-k/2)大的数。
然后不停递归比较约减下去。
当最后k-k/2=1的时候,表示我们找到了我们要找的那个数。
伪代码:
function findMedianSortedArrays(A,B):
m=A.size;
n=B.size;
left=(m+n+1)>>1;
right=(m+n+2)>>1;
return (getkth(A,m,B,n,l)+getkth(A,m,B,n,r))/2.0;
end function;
function getkth(s[],m,l[],n,k):
if m>n then:
return getkth(l,n,s,m,k);
end if;
if m==0 then:
//一个数组全部约减掉,等于在另一个数组找第k大的数。
return l[k-1]
end if;
if k==1 then:
//找最小的数,表示我们约减掉的已经足够多了
//只需要比较约减后的两个数组的首元素哪个最小
return min(s[0], l[0]);
end if;
//决定约减哪个数组
int i=min(m,k/2),j=min(n,k/2);
if s[i-1]>l[j-1] then:
return getkth(s,m, l+j,n-j,k-j);
else
return getkth(s+i,m-i,l,n,k-i);
return 0;
end function;
正确性分析:
正确性内含在伪代码注释和问题分析中的。
时间复杂度分析:
子问题简化式子: T ( n ) = T ( 3 n 4 ) + k , k 是 个 常 数 T(n)=T(\frac{3n}{4})+k,k是个常数 T(n)=T(43n)+k,k是个常数,这里的n=m+n,原问题的规模大小是两个数组长度。
故时间复杂度为: O ( log ( m + n ) ) O(\log{(m+n)}) O(log(m+n))
c) 基于中位数的二分查找
首先理解中位数的用处。
从统计上,中位数被用于:将一个集合划分为两个等长的子集,其中一个子集的元素都大于另一个子集。
对数组A,在i位置切分成两部分,对数组B,在j位置切成两部分,
把left_A和left_B放在一个集合,把right_A和 right_B放进另一个集合。把它们命名为: left_part和right_part:
left_part | right_part
A[0], A[1], ..., A[i-1] | A[i], A[i+1], ..., A[m-1]
B[0], B[1], ..., B[j-1] | B[j], B[j+1], ..., B[n-1]
同时必须确保这两个集合的长度相等,一个集合内的元素总是大于另一个集合内的元素,即:
- len(left_part)=len(right_part)
- max(left_part)≤min(right_part)
然后,我们就可以把{A, B}内所有元素划分成两个子集,其中一个子集内的所有元素总是大于另一个子集。然后;
m
e
d
i
a
n
=
m
a
x
(
l
e
f
t
_
p
a
r
t
)
+
m
i
n
(
r
i
g
h
t
_
p
a
r
t
)
2
median= \frac{max(left\_part)+min(right\_part)}{2}
median=2max(left_part)+min(right_part)
为了确保这两个条件成立,我们只需确保:
i+ j = m−i+n−j (or: m - i + n - j + 1)
如果n ≥ \geq ≥m,我们需要这样设置;i=0~m, j= m + n + 1 2 − i \frac{m + n + 1}{2} - i 2m+n+1−i
B[j-1] ≤ \leq ≤ A[i] 和 A[i-1] ≤ \leq ≤ B[j]
ps. 1 简单起见,我们假设A[i-1],B[j-1],A[i],B[j]总是合法,尽管i=0,i=m,j=0或者j=n这种情况。后面将会告知如何处理这些边缘值。
ps. 2 为什么n ≥ \ge ≥m? 因为我们必须确保j是非负的,因为i=0~m,j= m + n + 1 2 − i \frac{m + n + 1}{2} - i 2m+n+1−i,如果n<m,j可能会是负值,最终导致错误的结果。
所以,我们需要做的所有事情是:
在区间[0,m]搜索一个i,这个i需要满足:
B[j−1]≤A[i] ,A[i-1] ≤ \le ≤B[j],j= m + n + 1 2 − i \frac{m + n + 1}{2} - i 2m+n+1−i
我们可以二分搜索,步骤如下所述:
-
设置 imin=0, imax=m, 然后开始在区间[imin, imax]内搜索.
-
设置 i = i m i n + i m a x 2 , j = m + n + 1 2 − i i=\frac{imin+imax}{2},j=\frac{m+n+1}{2}-i i=2imin+imax,j=2m+n+1−i
-
现在我们有len(left_part)=len(right_part).我们将会遇到三种情况
-
B[j-1] ≤ \le ≤ A[i] 和 A[i-1] ≤ \le ≤ B[j]
这说明我们搜索到了目标i,停止搜索。
-
B[j-1]>A[i]
说明A[i]太小。我们必须调整i使B[j-1] ≤ \le ≤A[i]满足。
我们能够增加i吗?
是的,因为i增加,j将减少.故B[j-1]减小,A[i]增大。然后B[j-1] ≤ \le ≤A[i]可以被满足。
我们能够减少i吗?
不可以,因为i减小,j将会增大。故B[j-1]增大,A[i]减少,B[j-1] ≤ \le ≤A[i]可以满足
所以我们必须增大i,也就是说,我们必须调整搜索范围[i+1,imax].
所以设置imin=i+1,然后进行第二步.
-
A[i-1]>B[j]:
说明A[i-1]太大,我们必须减小i,使A[i-1] ≤ \le ≤B[j]得以满足。也就是说,我们必须调整搜索范围。
因此我们设置imax=i-1,然后进行2.
-
当找到目标i,中位数是:
max(A[i−1],B[j−1]),m+n是奇数
m a x ( A [ i − 1 ] , B [ j − 1 ] ) + m i n ( A [ i ] , B [ j ] ) 2 \frac{max(A[i−1],B[j−1])+min(A[i],B[j])}{2} 2max(A[i−1],B[j−1])+min(A[i],B[j]),m+n是偶数。
考虑边际值的情况,i=0,i=m,j=0,j=n 此时A[i-1],B[j-1],A[i],B[j]可能不存在。事实上这种情况比你想的更容易。
我们需要确保max(left_part) ≤ \le ≤min(right_part).因此,如果i和j不是边际值(意味着A[i-1],B[j-1],A[i],B[j]都存在),此时我们必须check B[j-1] ≤ \le ≤A[i],A[i-1] ≤ \le ≤B[j].但是一旦A[i-1],B[j-1],A[i],B[j]中有一些不存在,我们将不必check这两个情况中的一个或所有。例如,如果i=0,A[i-1]不存在,然后我们就不必checkA[i-1] ≤ \le ≤B[j],所以我们需要做的是:
在[0, m]的区间找一个i,满足以下:
(j=0 or i=m or B[j-1] ≤ \le ≤ A[i]) and
(i=0 or j=n or A[i-1] ≤ \le ≤B[j]),此时 j= m + n + 1 2 − i \frac{m+n+1}{2}-i 2m+n+1−i
在一个搜索循环中,我们可能遭遇三种情况:
(j=0 or i=m or B[j-1] ≤ \le ≤ A[i]) and
(i=0 or j=n or A[i-1] ≤ \le ≤B[j])说明i是最好的,我们可以停止搜索。
j>0 and i < m and B[j - 1] > A[i]
说明i太小,我们必须增大i
i>0 and j < n and A[i−1]>B[j]
说明i太大,我们减小i
对于情况2,3,我们不必去check是否j>0,j<n.
伪代码:
function midOfTwoSortArr(A[],B[]):
//初始化
int m=A.size;
int n=B.size;
if m>n then:
int[] temp=A;A=B;B=temp;
int tmp=m; m=n; n=tmp;
end if;
imin=0;imax=m;
halflen=(m+n+1)>>1;
//第一种情况
while(imin<imax):
int i=(imin+imax)>>1;
int j=halflen-i
if B[j-1]>A[i] then:
//i太小,增大i
imin=i+1;
else if A[i-1]>B[j] then:
//i太大,减小i
imax=i-1;
else:
int maxleft=0;
//找到目标i
//处理边际值
if i==0 then: maxleft=B[j-1];
else if j==0 then: maxleft=A[i-1];
else: maxleft=max(A[i-1],B[i-1]);
//根据i得到
if (m+n)%2==1 then:
//是奇数
return maxleft;
end if;
int minright=0;
if i==m then: minright=B[j];
else if j==n then: minright=A[i];
else: minright=min(B[j],A[i]);
end if;
return (maxleft+minright)<<1;
end if;
end while;
end function
AC的cpp代码:
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int m=nums1.size();
int n=nums2.size();
if (m>n){
vector<int> temp=nums1; nums1=nums2;nums2=temp;
int tmp=m; m=n; n=tmp;
}
int imin=0;
int imax=m;
int halflen=(m+n+1) >> 1;
while (imin<=imax){
int i = (imin + imax) >> 1;
int j = halflen - i;
if ((i<imax) && (nums2[j-1]>nums1[i])){
//i太小,增大i
imin=i+1;
}else if((i>imin)&&(nums1[i-1]>nums2[j])){
//i太大,减小i
imax=i-1;
}else{
int maxleft=0;
if(i==0)
maxleft=nums2[j-1];
else if(j==0)
maxleft=nums1[i-1];
else
maxleft=max(nums1[i-1],nums2[j-1]);
if((m+n)%2==1){
//m+n是奇数
return maxleft;
}
int minright=0;
if (i==m)
minright=nums2[j];
else if(j==n)
minright=nums1[i];
else
minright=min(nums2[j],nums1[i]);
return (maxleft+minright)/2.0;
}
}
return 0.0;
}
};
正确性分析:
略。
时间复杂度分析:
子问题简化式子: 是 常 数 T ( n ) = T ( n 2 ) + k , k 是 常 数 是常数T(n)=T(\frac{n}{2})+k,k是常数 是常数T(n)=T(2n)+k,k是常数,这里的n是两个数组长度最小的那个长度(从代码中可以看出)。
故时间复杂度: O ( log min ( m , n ) ) O(\log{\min{(m,n)}}) O(logmin(m,n))