C版本
// Returns the larger of the two integers; when they are equal, `a` is returned.
int getMax(int a, int b)
{
    if (a < b) {
        return b;
    }
    return a;
}
// Returns the smaller of the two integers; when they are equal, `b` is returned.
int getMin(int a, int b)
{
    if (a < b) {
        return a;
    }
    return b;
}
double findMedianSortedArrays(int* nums1, int nums1Size, int* nums2, int nums2Size) {
int m = nums1Size;
int n = nums2Size;
if(m > n){
int *temp = nums1; nums1 = nums2; nums2 = temp;
int tmp = m; m = n; n = tmp;
}
int *A = nums1;
int *B = nums2;
int iMin = 0, iMax = m, halfLen = (m + n + 1)/2;
while(iMin <= iMax){
int i = (iMin + iMax)/2;
int j = halfLen - i;
if(i < iMax && B[j-1] > A[i]){
iMin = i + 1; //i is too small
}
else if(i > iMin && A[i-1] > B[j]){
iMax = i - 1; //i is too big
}
else { // i is perfect
int maxLeft = 0;
if (i == 0){maxLeft = B[j-1];}
else if (j == 0){maxLeft = A[i-1];}
else {maxLeft = getMax(A[i-1],B[j-1]);}
if ((m + n) % 2 ==1){return maxLeft;}
int minRight = 0;
if (i == m){minRight = B[j];}
else if (j == n){minRight = A[i];}
else {minRight = getMin(B[j],A[i]);}
return (maxLeft + minRight)/2.0;
}
}
return 0.0;
}
C++版本
/// Median of two ascending-sorted integer vectors, found by selecting the
/// k-th smallest element of their union without merging them.
class Solution {
public:
    /// Returns the median of the combined contents of `nums1` and `nums2`.
    ///
    /// Both vectors are assumed to be sorted in ascending order (not checked).
    /// Either may be empty. When both are empty there is no median and 0.0 is
    /// returned.
    double findMedianSortedArrays(std::vector<int>& nums1, std::vector<int>& nums2) {
        const int total = static_cast<int>(nums1.size() + nums2.size());
        if (total == 0) {
            // Without this guard findKth would index nums2[-1].
            return 0.0;
        }

        // 1-based ranks of the two middle elements; they coincide when
        // `total` is odd.
        const int left = (total + 1) / 2;
        const int right = (total + 2) / 2;

        const int lower = findKth(nums1, 0, nums2, 0, left);
        const int upper = (right == left) ? lower : findKth(nums1, 0, nums2, 0, right);

        // Add in double: lower + upper can overflow int
        // (e.g. INT_MAX + INT_MAX), which is undefined behaviour.
        return (static_cast<double>(lower) + static_cast<double>(upper)) / 2.0;
    }

    /// Returns the k-th smallest element (1-based) of the union of
    /// `nums1[i..]` and `nums2[j..]`.
    ///
    /// Preconditions (not checked): `i` and `j` are non-negative, and
    /// 1 <= k <= number of elements remaining in the two suffixes.
    /// Each step discards k/2 elements, so the number of steps is O(log k).
    int findKth(std::vector<int>& nums1, int i, std::vector<int>& nums2, int j, int k) {
        const int size1 = static_cast<int>(nums1.size());
        const int size2 = static_cast<int>(nums2.size());

        while (true) {
            // One side exhausted: the answer is a direct index into the other.
            if (i >= size1) return nums2[j + k - 1];
            if (j >= size2) return nums1[i + k - 1];
            if (k == 1) return std::min(nums1[i], nums2[j]);

            const int step = k / 2;
            // A side with fewer than `step` elements left gets INT_MAX so that
            // the other side (which then must have at least `step`) is the
            // one that is advanced.
            const int midVal1 = (i + step - 1 < size1) ? nums1[i + step - 1] : INT_MAX;
            const int midVal2 = (j + step - 1 < size2) ? nums2[j + step - 1] : INT_MAX;

            // The `step` elements ending at the smaller probe all rank below
            // k in the union, so they can be skipped.
            if (midVal1 < midVal2) {
                i += step;
            } else {
                j += step;
            }
            k -= step;
        }
    }
};