参考了别人的英文题解
同时遍历A,B的复杂度是O(n+m),不合题意。这是因为我们寻找total/2大的元素,在每次寻找的过程(迭代)中,只剔除了1个元素;
如果每次剔除一半的元素,就可以保证O(log(m+n))的复杂度了。
第二次:
C++
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int sz = nums1.size() + nums2.size();
if (sz % 2 == 0)
return (find_kth(nums1.data(), nums1.size(), nums2.data(), nums2.size(), sz / 2)
+ find_kth(nums1.data(), nums1.size(), nums2.data(), nums2.size(), sz / 2 + 1)) / 2;
else
return find_kth(nums1.data(), nums1.size(), nums2.data(),
nums2.size(), sz / 2 + 1);
}
private:
double find_kth(const int *a, size_t m, const int *b, size_t n, size_t k) {
if (m > n)
return find_kth(b, n, a, m, k);
else if (m == 0)
return b[k-1];
else if (k == 1)
return min(a[0], b[0]);
size_t pa = min(m, k/2);
size_t pb = k - pa; // find out the evaluation order of commma operator, plz
if (a[pa-1] < b[pb-1])
return find_kth(a+pa, m-pa, b, n, k-pa);
else if (a[pa-1] > b[pb-1])
return find_kth(a, m, b+pb, n-pb, k-pb);
else
return a[pa-1];
}
};
python:
class Solution(object):
def findMedianSortedArrays(self, nums1, nums2):
sz = len(nums1) + len(nums2)
if sz % 2 == 1:
return self.find_kth(nums1, nums2, sz / 2 + 1)
else:
return (self.find_kth(nums1, nums2, sz / 2)
+ self.find_kth(nums1, nums2, sz / 2 + 1)) / 2.0
def find_kth(self, a, b, k):
k = int(k)
m = len(a)
n = len(b)
if m > n:
return self.find_kth(b, a, k)
elif m == 0:
return b[k-1]
elif k == 1:
return min(a[0], b[0])
pa = int(min(m, k/2))
pb = k - pa
if a[pa-1] < b[pb-1]:
return self.find_kth(a[pa:], b, k-pa)
elif a[pa-1] > b[pb-1]:
return self.find_kth(a, b[pb:], k-pb)
else:
return a[pa-1]
第一次:
考虑寻找两个升序数组A, B中第k大的元素,对于A[k/2 - 1], B[k/2 - 1]有三种情况(coding时可保证k的奇偶性不影响答案):
1. 若A[k/2 -1] < B[k/2 - 1], 则A[k/2 -1] < A,B中第k大的元素(下称k-th)
反证:若A[k/2 - 1] >= k-th, 则A中至多有k/2-1个元素(数组下标从0开始记),同时B中至多有k/2-1个元素小于k-th. 从而A,B中至多只有2*(k/2-1) = k-2个元素小于k-th. 矛盾!
从而必有A[k/2-1] < k-th. 所以此时可以去掉A中的前k/2个元素
2. 若A[k/2 - 1] > B[k/2 - 1], 则B[k/2 - 1] < A,B中第k大的元素(下称k-th)
同理。此时可以去掉数组B中的前k/2个元素。
3. A[k/2 - 1] == B[k/2 - 1]
此时A[k/2-1], B[k/2-1]都是A,B中第k大的元素 ,亦即找到了答案。
代码:
class Solution
{
public:
double findMedianSortedArrays(int A[], int m, int B[], int n)
{
if ((m+n) % 2 == 1)
{
return find_kth(A,m,B,n,(m+n)/2+1);
} else // even
{
return (find_kth(A,m,B,n,(m+n)/2)+find_kth(A,m,B,n,(m+n)/2+1)) / 2.0;
}
}
private:
int find_kth(int A[], int m, int B[], int n, int k)
{
if (m == 0)
{
return B[k-1];
} else if (m > n) // maintain m < n
{
return find_kth(B, n, A, m, k);
} else if (k == 1)
{
return min(A[0], B[0]);
}
int pa = min(k/2, m), pb = k-pa;
if (A[pa-1] < B[pb-1])
{
return find_kth(A+pa, m-pa, B, n, k-pa); // maintain pointers and k-th here
} else if (A[pa-1] > B[pb-1])
{
return find_kth(A, m, B+pb, n-pb, k-pb); // maintain pointers and k-th here
} else
{
return A[pa-1];
}
}
};