题面
There are two sorted arrays nums1 and nums2 of size m and n respectively.
Find the median of the two sorted arrays. The overall run time complexity should be O(log (m+n)).
Example 1:
nums1 = [1, 3]
nums2 = [2]
The median is 2.0
Example 2:
nums1 = [1, 2]
nums2 = [3, 4]
The median is (2 + 3)/2 = 2.5
分析
题目的要求是找出两个数组中的中位数。并且时间复杂度为O(log (m+n))。
观察时间复杂度,是O(log (m+n)),是一个相当优秀的算法的复杂度了。对于两个排好序的数组,要找出中位数,主要问题是在两个数组中找到一个数,假设这个数在第一个数组中大于m个数,在第二个数组中大于n个数,且满足以下关系:
//if nums1.size() + nums2.size() is odd
m + n == (nums1.size() + nums2.size()) / 2;
// if nums1.size() + nums2.size() is even, we only need to excute two times to get m1, n1 and m2, n2
m1 + n1 == (nums1.size() + nums2.size() / 2;
m2 + n2 == (nums2.size() + nums2.size() / 2 + 1;
因为是已经排好序的数组,所以找到一个数后可以马上知道m和n。
那么接下来就是找这个数的策略问题了,问题可以转化为怎么在只用O(log (m+n))次数就找到满足条件的数呢?
回想二分查找的时间复杂度,就是O(log n),二分查找通过每次舍弃一半的数可以在O(log n)的时间内找到一个数在数组中的位置。但是二分查找是针对一个数组而言的,那么问题就转化到如何在两个数组上应用二分查找了。如果简单地对两个数组合并排序,那么这样的复杂度显然是不能接受的。
因此,我们思考如何每次去掉一半的数呢。考虑这样的做法,每次找两个数组的中间的值,这个行为O(1)就可以做到,假设这两个数是mid1, mid2。
如果mid1 < mid2,可以知道,mid2 所在的数组中的大于mid2的数都不可能是两个数组的中位数,可以舍去;同理,如果mid1 > mid2,那么mid1所在的数组中的大于mid1的数都不可能是两个数组的中位数,可以舍去。而每个这样找数的过程都是一个缩小范围的子过程,因此可以使用分治的思想来解决。
同时,注意到,可以再将这样的思想泛化到寻找两个数组中第k大的数,而不再仅仅局限于中位数。
代码
int min(int a, int b) {
return (a < b)? a : b;
}
int getkth(int nums1[], int m, int nums2[], int n, int k) {
if (n > m)
return getkth(nums2, n, nums1, m, k);
if (n == 0)
return nums1[k - 1];
if (k == 1)
return min(nums1[0], nums2[0]);
int size2 = min(n, k / 2), size1 = k - size2;
if (nums1[size1 - 1] > nums2[size2 - 1])
return getkth(nums1, m, nums2 + size2, n - size2, k - size2);
if (nums1[size1 - 1] < nums2[size2 - 1])
return getkth(nums1 + size1, m - size1, nums2, n, k - size1);
return nums1[size1 - 1];
}
double findMedianSortedArrays(int* nums1, int nums1Size, int* nums2, int nums2Size) {
int sum_of_size = nums1Size + nums2Size;
if (sum_of_size % 2 != 0)
return (getkth(nums1, nums1Size, nums2, nums2Size, sum_of_size / 2 + 1)) * 1.0;
return (getkth(nums1, nums1Size, nums2, nums2Size, sum_of_size / 2) + getkth(nums1, nums1Size, nums2, nums2Size, sum_of_size / 2 + 1)) / 2.0;
}
运行结果
错误思路
第一次做这道题的时候,也是先想到了使用二分的思想来完成,但是在找数的过程中,没有做到上面算法的简洁,最终也只能达到O(log(m) * log(n)).
代码:
class Solution {
public:
//binary search to find how many nums smaller than the num
int binary_search(int begin, int end, vector<int>& nums, int num) {
//notice that the operator is <=
while (begin <= end) {
int mid = (begin + end) / 2;
if (nums[mid] > num) {
end = mid - 1;
} else if (nums[mid] < num) {
begin = mid + 1;
} else {
return mid;
}
}
return begin;
}
void swap(vector<int>& nums1, vector<int>& nums2) {
vector<int> temp = nums1;
nums1 = nums2;
nums2 = temp;
}
void swap(int& a, int& b) {
int temp = a;
a = b;
b = temp;
}
int findOneMedianSortedArrays(vector<int>& nums1, vector<int>& nums2, int location) {
//begin代表待搜索数组的起始,end代表待搜索数组的终止
//another_begin代表产生搜索数的数组的起始
int begin = 0, end = nums2.size() - 1, another_begin = 0, another_end = nums1.size() - 1, mid;
vector<int>& search_nums = nums2;
vector<int>& num_in_nums = nums1;
if (nums1.size() == 0) return nums2[location];
if (nums2.size() == 0) return nums1[location];
while (1) {
cout << "begin: " << begin << ' ' << "end: " << end << endl;
cout << "another_begin: " << another_begin << ' ' << "another_end: " << another_end << endl;
mid = (another_begin + another_end) / 2;
cout << "mid: " << mid << endl;
cout << "num_in_nums[mid]: " << num_in_nums[mid] << endl;
int count = binary_search(begin, end, search_nums, num_in_nums[mid]);
cout << "count: " << count << endl;
cout << endl;
if (count + mid == location) return num_in_nums[mid];
swap(search_nums, num_in_nums);
if (count + mid < location) {
another_begin = mid + 1;
begin = count;
swap(begin, another_begin);
swap(end, another_end);
} else {
another_end = mid - 1;
end = count - 1;
swap(begin, another_begin);
swap(end, another_end);
}
}
}
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int sum_of_size = nums1.size() + nums2.size();
if (sum_of_size % 2) {
int median_location = sum_of_size / 2;
return (findOneMedianSortedArrays(nums1, nums2, median_location) * 1.0);
} else {
int median1_location = sum_of_size / 2 - 1;
int median2_location = sum_of_size / 2;
// return findOneMedianSortedArrays(nums1, nums2, median1_location);
return (findOneMedianSortedArrays(nums1, nums2, median1_location) + findOneMedianSortedArrays(nums1, nums2, median2_location)) / 2.0;
}
}
};