Give an O(log m + log n) algorithm that finds the k-th element of two sorted arrays A and B (of sizes m and n respectively), i.e. the element that would be at position k if the two arrays were merged into a single sorted array.
namespace details
{
/// Binary-search core of kth_element for two sorted ranges.
///
/// Returns an iterator (into either [b1, e1) or [b2, e2)) to the element that
/// would sit at zero-based position `k` if both ranges were merged into one
/// sorted sequence.
///
/// Preconditions (checked with assert):
///   * `0 < k < n + m - 1`, where n and m are the lengths of the two ranges;
///     the smallest and largest positions are expected to be handled by the
///     caller.
///   * both ranges are sorted in non-decreasing order (not checked).
///
/// `I` must support random-access arithmetic (`it + offset`). The search
/// halves an index interval over the first range on every iteration, so it
/// performs O(log n) element comparisons.
template <typename I>
I kth_element( I b1, I e1, I b2, I e2, size_t k )
{
    // Work in the iterator's own difference type so that large ranges are not
    // truncated (the indices must be signed: j = kth - i may go negative).
    using diff_t = typename std::iterator_traits< I >::difference_type;

    const diff_t n = std::distance( b1, e1 );
    const diff_t m = std::distance( b2, e2 );
    const size_t total = static_cast< size_t >( n ) + static_cast< size_t >( m );
    assert( total >= 2 && k > 0 && k < total - 1 );
    static_cast< void >( total ); // only used by the assert above
    const diff_t kth = static_cast< diff_t >( k );

    // With one range empty the merged sequence is just the other range.
    if ( n == 0 ) {
        return b2 + kth;
    }
    if ( m == 0 ) {
        return b1 + kth;
    }

    // Search for an index i into the first range such that, together with
    // j = kth - i elements of the second range, either *(b1 + i) or *(b2 + j)
    // has exactly kth elements in front of it in the merged order.
    diff_t lo = 0;
    diff_t hi = n;
    while ( lo < hi ) {
        const diff_t i = lo + ( hi - lo ) / 2; // overflow-safe midpoint
        const diff_t j = kth - i;

        if ( j < 0 ) {
            // i alone already exceeds kth: too far right.
            hi = i;
            continue;
        }
        if ( j >= m ) {
            // Would need more elements from the second range than it has.
            lo = i + 1;
            continue;
        }

        if ( *( b1 + i ) <= *( b2 + j ) ) {
            // *(b1 + i) has rank i + j iff the j-th element of the second
            // range (index j - 1) does not come after it.
            if ( j == 0 || *( b2 + ( j - 1 ) ) <= *( b1 + i ) ) {
                return b1 + i;
            }
            lo = i + 1;
        } else {
            // Symmetric case: *(b2 + j) has rank i + j iff the i-th element
            // of the first range (index i - 1) does not come after it.
            if ( i == 0 || *( b1 + ( i - 1 ) ) <= *( b2 + j ) ) {
                return b2 + j;
            }
            hi = i;
        }
    }

    // No probe matched. Probing i == 0 either matches or moves lo to 1, and
    // hi is never set to 0, so lo (== hi) is positive here.
    assert( lo > 0 );
    if ( lo < n ) {
        // The interval closed strictly inside the first range: every element
        // of the second range precedes *(b1 + lo), so all m of them are
        // consumed and the answer is at index kth - m of the first range.
        assert( *( b1 + lo ) >= *( b2 + ( m - 1 ) ) );
        return b1 + ( kth - m );
    }
    // The whole first range lies before position kth, so the answer is at
    // index kth - n of the second range.
    return b2 + ( kth - n );
}
}
/// Finds the element at zero-based position `k` of the merged order of two
/// sorted ranges [b1, e1) and [b2, e2), without actually merging them.
///
/// Returns an iterator into whichever range holds that element. When `k` is
/// out of range the result is an end iterator: `e2` if only the second range
/// is non-empty, `e1` in every other case.
template <typename I>
I kth_element( I b1, I e1, I b2, I e2, size_t k )
{
    const size_t len1 = std::distance( b1, e1 );
    const size_t len2 = std::distance( b2, e2 );

    // At most one range has elements: plain indexing into that range.
    if ( len1 == 0 ) {
        if ( len2 == 0 ) {
            return e1;
        }
        return k < len2 ? b2 + k : e2;
    }
    if ( len2 == 0 ) {
        return k < len1 ? b1 + k : e1;
    }

    // Both ranges are non-empty from here on.
    const size_t final_pos = len1 + len2 - 1;

    if ( k == 0 ) {
        // Smallest overall: the smaller of the two fronts (first range on ties).
        if ( *b1 <= *b2 ) {
            return b1;
        }
        return b2;
    }
    if ( k == final_pos ) {
        // Largest overall: the larger of the two backs (first range on ties).
        if ( *( b1 + len1 - 1 ) >= *( b2 + len2 - 1 ) ) {
            return b1 + len1 - 1;
        }
        return b2 + len2 - 1;
    }
    if ( k < final_pos ) {
        // Strictly interior position: delegate to the binary search.
        return details::kth_element( b1, e1, b2, e2, k );
    }

    // k is past the last merged position.
    return e1;
}
The testing code:
// Randomised check of v1::kth_element.
//
// For every pair of lengths taken from the table below, two vectors are
// produced with random_sorted_vector (presumably sorted vectors of the
// requested length -- confirm against its definition) and kth_element is
// queried for every k in [0, n + m). The only property verified is that the
// returned values are non-decreasing in k; the results are not compared
// against an actual merge. On a violation both inputs and the offending
// sizes/k are printed and debug_break() is called.
void verify_kth_element_in_two_sorted_arrays()
{
    // Lengths of the generated vectors.
    size_t primes[] = { 1, 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293 };

    // Prints one vector on a single line.
    auto dump = []( const std::vector<int>& values ) {
        std::for_each( values.begin(), values.end(), []( int n ) { printf( "%d ", n ); } );
        printf( "\n" );
    };

    for ( size_t i = 0; i < _countof( primes ); ++i ) {
        std::vector<int> d1( random_sorted_vector( primes[ i ] ) );
        for ( size_t j = 0; j < _countof( primes ); ++j ) {
            std::vector<int> d2( random_sorted_vector( primes[ j ] ) );
            auto previous = v1::kth_element( std::begin( d1 ), std::end( d1 ), std::begin( d2 ), std::end( d2 ), 0 );
            for ( size_t k = 1, ke = primes[ i ] + primes[ j ]; k < ke; ++k ) {
                auto current = v1::kth_element( std::begin( d1 ), std::end( d1 ), std::begin( d2 ), std::end( d2 ), k );
                //verify( *previous <= *current, stderr, "Failed to find the %uth element from the two sorted array\n", k - 1 );
                if ( *previous > *current ) {
                    dump( d1 );
                    dump( d2 );
                    // size_t arguments need %zu; %u is a format/argument mismatch.
                    printf( "n = %zu, m = %zu, k = %zu\n", primes[ i ], primes[ j ], k );
                    debug_break();
                }
                previous = current;
            }
            //printf( "n = %u, m = %u case PASS\n", primes[ i ], primes[ j ] );
        }
    }
}