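/*
 * k-th smallest element across N sorted arrays.
 *
 * Divide and conquer: each round shrinks the problem from rank k to roughly
 * k*(N-1)/N. Because each round removes only about a 1/N fraction of k, the
 * number of rounds is about log_{N/(N-1)}(k), i.e. O(N*log k), not O(log N).
 *
 * Corner cases to keep in mind:
 * 1. k > n1 + n2 + ... + nN (k exceeds the total number of elements).
 * 2. Only one array remains.
 * 3. k < N: take the first k elements of every array, then select the k-th
 *    smallest among those candidates.
 */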
#include <algorithm>
#include <climits>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <stdexcept>
#include <vector>
using namespace std;
// Orders arrays by ascending length. find_kth sorts its inputs with it so that
// the shortest arrays are given their share of the rank budget first.
// Arguments are taken by const reference; by value, every comparison copied
// both vectors.
bool comp(const std::vector<int>& a1, const std::vector<int>& a2)
{
    return a1.size() < a2.size();
}

// Predicate for std::remove_if: true when an input array has no elements.
static bool is_empty_array(const std::vector<int>& a)
{
    return a.empty();
}

// Returns the k-th smallest element (1-based) of the union of the given
// arrays. Each array must be sorted in ascending order. Duplicates count with
// multiplicity.
//
// Each round works as follows:
//   1. Split k among the n non-empty arrays, shortest first.
//   2. Take that many leading elements from each array.
//   3. Pick the array whose last taken element is smallest.
// Fewer than k elements overall are strictly smaller than that value. So the
// picked array's whole taken prefix lies among the k smallest. It is discarded
// and k shrinks by its length.
//
// @param array sorted input arrays (taken by value; the copy is consumed).
// @param k     1-based rank to select.
// @return      the k-th smallest element.
// @throws std::out_of_range if k < 1 or k exceeds the total element count
//         (including the case where every array is empty).
int find_kth(std::vector<std::vector<int> > array, int k)
{
    // Drop empty arrays with erase-remove. Erasing inside an index loop skips
    // the element that slides into the erased slot, so a run of consecutive
    // empty arrays would survive.
    array.erase(std::remove_if(array.begin(), array.end(), is_empty_array),
                array.end());

    std::size_t total = 0;
    for (std::size_t i = 0; i < array.size(); i++)
        total += array[i].size();
    if (k < 1 || static_cast<std::size_t>(k) > total)
        throw std::out_of_range("find_kth: k is outside [1, total size]");

    // Iterate instead of recursing, so the arrays are not copied every round.
    // Invariant: every array is non-empty and 1 <= k <= remaining total.
    for (;;)
    {
        const int n = static_cast<int>(array.size());
        if (n == 1)
            return array[0][k - 1];

        std::sort(array.begin(), array.end(), comp);

        if (k < n)
        {
            // Fewer ranks than arrays: the answer is among the first k
            // elements of each array.
            std::vector<int> candidates;
            for (int i = 0; i < n; i++)
            {
                const std::size_t take =
                    std::min(static_cast<std::size_t>(k), array[i].size());
                candidates.insert(candidates.end(), array[i].begin(),
                                  array[i].begin() + static_cast<std::ptrdiff_t>(take));
            }
            std::nth_element(candidates.begin(), candidates.begin() + (k - 1),
                             candidates.end());
            return candidates[k - 1];
        }

        // Split k among the arrays, shortest first. Each array takes an equal
        // share of the still-unassigned ranks, capped at its length. The arrays
        // are sorted by length, k <= total, and k >= n. Hence the shares sum to
        // exactly k and every share is at least 1.
        std::vector<int> share(n);
        int remaining = k;
        for (int i = 0; i < n; i++)
        {
            const int quota = remaining / (n - i);
            share[i] = static_cast<int>(
                std::min(static_cast<std::size_t>(quota), array[i].size()));
            remaining -= share[i];
        }

        // Choose the array whose last shared element is smallest (first one on
        // ties). Starting from index 0 rather than an INT_MAX sentinel keeps
        // this valid even when every candidate value is INT_MAX.
        int best = 0;
        for (int i = 1; i < n; i++)
            if (array[i][share[i] - 1] < array[best][share[best] - 1])
                best = i;

        // That prefix lies entirely within the k smallest: discard it. Since
        // n >= 2 and every share is >= 1, share[best] < k, so k stays >= 1.
        array[best].erase(array[best].begin(), array[best].begin() + share[best]);
        k -= share[best];
        if (array[best].empty())
            array.erase(array.begin() + best);
    }
}
// Returns the k-th smallest element (1-based) of the union of the first m
// elements of A and the first n elements of B. Both prefixes must be sorted in
// ascending order.
//
// Two-array selection: pick a + b = k and compare A[a-1] with B[b-1]. The
// side with the smaller value has its whole a- (or b-) prefix below the
// answer, so that prefix is discarded and k shrinks. Offsets are advanced in a
// loop, instead of copying the remaining suffix into a new vector each step.
//
// @throws std::out_of_range if m or n is negative or larger than its vector,
//         or if k is outside [1, m + n].
int find_kth2(std::vector<int> A, int m, std::vector<int> B, int n, int k)
{
    if (m < 0 || n < 0 ||
        static_cast<std::size_t>(m) > A.size() ||
        static_cast<std::size_t>(n) > B.size() ||
        k < 1 || static_cast<long long>(k) > static_cast<long long>(m) + n)
        throw std::out_of_range("find_kth2: k or array length out of range");

    // Each side is (array, start offset, remaining length).
    const std::vector<int>* x = &A;
    const std::vector<int>* y = &B;
    int xo = 0, xn = m;
    int yo = 0, yn = n;
    for (;;)
    {
        // Keep x as the shorter side, so the a-prefix always fits inside it.
        if (xn > yn)
        {
            std::swap(x, y);
            std::swap(xo, yo);
            std::swap(xn, yn);
        }
        if (xn == 0)
            return (*y)[yo + k - 1];
        if (k == 1)
            return std::min((*x)[xo], (*y)[yo]);

        const int a = std::min(xn, k / 2);
        const int b = k - a;  // b <= yn because k <= xn + yn and xn <= yn
        const int xa = (*x)[xo + a - 1];
        const int yb = (*y)[yo + b - 1];
        if (xa < yb)
        {
            // x's first a elements all rank below the answer.
            xo += a;
            xn -= a;
            k -= a;
        }
        else if (xa > yb)
        {
            // y's first b elements all rank below the answer.
            yo += b;
            yn -= b;
            k -= b;
        }
        else
        {
            // Equal: the two prefixes hold k elements none larger than this
            // value, so it is the k-th smallest.
            return xa;
        }
    }
}
int main()
{
vector<int> v1(100);
vector<int> v2(200);
vector<int> v3(300);
vector<int> v4(400);
vector<int> v5(500);
vector<int> merged(1500);
vector<vector<int> > array(5);
array[0] = v1;
array[1] = v2;
array[2] = v3;
array[3] = v4;
array[4] = v5;
int count = 0;
for(int i=0; i<array.size(); i++)
{
for(int j=0; j<array[i].size(); j++)
{
int rad = rand()%100000;
array[i][j] = rad;
merged[count] = rad;
count++;
}
std::sort(array[i].begin(), array[i].end());
}
std::sort(merged.begin(), merged.end());
/*
for(int k=0; k<array.size(); k++)
{
for(int i=0; i<array[k].size(); i++)
cout << array[k][i] << " ";
cout << endl;
}*/
//cout << find_kth(array, 1333) << endl;
int t;
for(int i=1; i<1500; i++)
if(merged[i-1] != (t = find_kth(array, i)))
cout << "error on " << i << " " << merged[i-1] << " " << t << endl;
return 0;
}