http://pat.zju.edu.cn/contests/pat-a-practise/1029
二分搜索median。
代码一: 时间复杂度log(m + n)。
思路:比较两个数组的中间元素,然后可以排除掉其中一个数组的一半。
假设数组A, B, 我们要找的是A, B合并为一个数组后的第median个元素。令A, B中间元素分别为A[mid1], B[mid2]:
如果 A[mid1] == B[mid2],则结果就为A[mid1];
如果 A[mid1] < B[mid2],这时B[mid2]的前面一共有 mid1 + mid2 + 1个数字,
此时需要判断mid1 + mid2 + 1和median 的大小:
如果前者大,意味着我们要找的数字就在这 mid1 + mid2 + 1个数字中,从而我们可以排除B数组中的后一半;
如果后者大,意味着我们要找的数字不在这 mid1 + mid2 + 1个数字中,从而我们可以排除A 数组中的前一半。
当A[mid1] > B[mid2]时情况类似。
#include <iostream>
#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
#include <stack>
using namespace std;
long long *a, *b;
long long findMedian(long long *v1, long long *v2, int len1, int len2, int indexOfMid)
{
if(len1 == 0) return v2[indexOfMid - 1];
if(len2 == 0) return v1[indexOfMid - 1];
int mid1 = (len1 - 1) >> 1, mid2 = (len2 - 1) >> 1;
int index = mid1 + mid2 + 1;
if(v1[mid1] == v2[mid2]) return v1[mid1];
if(v1[mid1] < v2[mid2])
{
if(index < indexOfMid) return findMedian(v1 + mid1 + 1, v2, len1 - mid1 - 1, len2, indexOfMid - mid1 - 1);
else return findMedian(v1, v2, len1, mid2, indexOfMid);
}
else
{
if(index < indexOfMid) return findMedian(v1, v2 + mid2 + 1, len1, len2 - mid2 - 1, indexOfMid - mid2 - 1);
return findMedian(v1, v2, mid1, len2, indexOfMid);
}
}
int main()
{
int m, n, i;
long long x;
scanf("%d", &m);
a = new long long[m + 10];
for(i = 0; i < m; i ++)
{
scanf("%lld", &a[i]);
}
scanf("%d", &n);
b = new long long[n + 10];
for(i = 0; i < n; i ++)
{
scanf("%lld", &b[i]);
}
int mid = (m + n + 1) >> 1;
long long ans = findMedian(a, b, m, n, mid);
printf("%lld\n", ans);
delete a;
delete b;
return 0;
}
代码二:时间复杂度logm*logn(想复杂了)
#include <iostream>
#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
#include <stack>
using namespace std;
long long *a, *b;
int binaryS1(long long *v, long long target, int size) // find first smaller than target
{
int start = 0, end = size - 1;
while(start < end)
{
int mid = (start + end) >> 1;
if(v[mid] >= target) end = mid - 1;
else start = mid + 1;
}
if(v[start] >= target) start -- ;
return start;
}
int binaryS2(long long *v, long long target, int size) // find first larger than target
{
int start = 0, end = size - 1;
while(start < end)
{
int mid = (start + end) >> 1;
if(v[mid] <= target) start = mid + 1;
else end = mid - 1;
}
if(v[start] <= target) start ++ ;
return start;
}
bool isFind = false;
long long findMedian(long long *v1, long long *v2, int size1, int size2)
{
int start = 0, end = size1 - 1;
int median = (size1 + size2) >> 1;
if((size1 + size2) % 2 == 1) median ++ ;
while(start <= end)
{
int mid = (start + end) >> 1;
int index1 = binaryS1(v2, v1[mid], size2);
int index2 = binaryS2(v2, v1[mid], size2);
int posMin = mid + 1 + index1 + 1;
int posMax = mid + 1 + index2;
if(posMin <= median && posMax >= median)
{
isFind = true;
return v1[mid];
}
else if(posMin > median) end = mid - 1;
else start = mid + 1;
}
return -1;
}
int main()
{
int m, n, i;
long long x;
scanf("%d", &m);
a = new long long[m + 10];
for(i = 0; i < m; i ++)
{
scanf("%lld", &a[i]);
}
scanf("%d", &n);
b = new long long[n + 10];
for(i = 0; i < n; i ++)
{
scanf("%lld", &b[i]);
}
long long ans = findMedian(a, b, m, n);
if(isFind == false) ans = findMedian(b, a, n, m);
printf("%lld\n", ans);
delete a;
delete b;
return 0;
}
当然这题还有O(m + n)的算法,即归并排序思想。