题目不难,关键是边界条件要想清楚。先写一个时间复杂度为O(K) 的解法。


#include <iostream>
using namespace std;

//a[] increase
//b[] decrease
//use ret_value to return the result
//function ret reprsent the error if not 0
int find_k(int a[], int b[], int m, int n, int k, int& ret_value) {
    if(k<=0 || m<0 || n<0 || k>(m+n)) {
        return -1;
    }

    int i = 0, x = m-1, y = 0;
    while (i<k && x>=0 && y<=n-1) {
        if (a[x] > b[y]) {
            ret_value = a[x];
        i++;
        x--;
    } else if (a[x] < b[y]) {
            ret_value = b[y];
        i++;
            y++;
    } else { //equal
            ret_value = a[x];
        i+=2;
        x--;
        y++;
    }
    }

    cout << "x:" << x << endl 
         << "y:" << y << endl
         << "i:" << i << endl
         << "v:" << ret_value << endl;

    i = k-i;
    if (i > 0) {
        if (x<0) {
            ret_value = a[y+i-1];
        } else if (y > n-1) {
            ret_value = a[x-i+1];
        }
    }

    return 0;
}

int main () {
   int a[] = {1,2,3};
   int b[] = {6,5,3};
   int k=0;

   cout << "please input the k:";
   cin >> k;

   int value = -1;
   find_k(a,b,3,3,k,value);
   cout << value << endl;
   
   return 0;
}


再附上一个时间复杂服为O(log n)的算法实现,

算法思想是:

每次从两个数组里数 k/2 个数,比较其大小,较大的那个数所在的数组里这 k/2 个数, 肯定在前k大个数里面;

这样就排除了k/2个数,然后再在剩下的数里面找第k/2大的数,循环到找到最后一个数为止。

 

为了简化一下,假设两个数组都为降序。

注意程序的31~34行,如果不判断相等,当最后还剩一个数的时候,第20行

mid_b = pb + kb -1;
mid_b 会被赋值为一个错误的位置,导致程序陷入死循环
//both a[] b[] decrease
int find_k_both_increase_O_log_k(int a[], int b[], int m, int n, int k, int& ret_value) {
    if(k<=0 || m<0 || n<0 || k>(m+n)) {
        return -1;
    }

    int ka = 0;
    int kb = 0;
    int pa = 0;
    int pb = 0;
    int mid_a = 0;
    int mid_b = 0;
    int i = k;

    while (i>0 && pa<m && pb<n) {
        ka = i/2;
        kb = i-ka;

        mid_a = pa + ka -1;
        mid_b = pb + kb -1;

        if (mid_a<m && mid_b<n) {
            if (a[mid_a] > b[mid_b]) {
                ret_value = a[mid_a];
                pa = mid_a + 1;
                i = i - ka;
            } else if (a[mid_a] < b[mid_b]) {
                ret_value = b[mid_b];
                pb = mid_b + 1;
                i = i - kb;
            } else { //equal
                ret_value = a[mid_a];
                return 0;
            }
        }
        
        cout << "pa:" << pa << endl
             << "pb:" << pb << endl
             << "i:" << i << endl;
        cout <<endl;
    }

    if (i > 0) {
        if (pa > m-1) {
            ret_value = a[pb-1+i];
        } else if (pb > n-1) {
            ret_value = a[pa-1+i];
        }
    }

    return 0;
}


在leetcode上碰到该题的改进版本,被标识为hard,看起来简单的题目,其实不简单。

附上题目和代码如下:

Median of Two Sorted Arrays

There are two sorted arrays A and B of size m and n respectively. Find the median of the two sorted arrays. The overall run time complexity should be O(log (m+n)).

找第k大个数,和我上面一个稍微有点不同,写起来很多边界条件,回头多看几遍。

int find(int A[], int m, int B[], int n, int k) {
    if (m==0) {
        return B[k-1];
    }
    
    if (n==0) {
        return A[k-1];
    }
    
    
    int pa=0, pb=0;
    int sa=0, sb=0;
    int ka=0, kb=0;
    
    while (k>1 && sa<m && sb<n) {
        ka = k/2;
        kb = k-ka;
        pa = sa + ka -1;
        pb = sb + kb -1;  
        if (pa >= m) {
            pa = m-1;
            ka = m-sa;
            
            kb = k-ka;
            pb = sb + kb -1;
        } else if (pb >= n) {
            pb = n-1;
            kb = n-sb;
            
            ka = k - kb;
            pa = sa + ka -1;
        }
        
        if (A[pa] > B[pb]) {
            k = ka;
            sb = pb+1;
        } else if (A[pa] < B[pb]) {
            k = kb;
            sa = pa+1;
        } else { //A[pa] == B[pb]
            return A[pa];
        }
    }
    
    if (sa>=m) {
        return B[sb+k-1];
    } else if (sb >= n) {
        return A[sa+k-1];
    }

    if (k == 1) {
        return A[sa] < B[sb] ? A[sa] : B[sb];
    }
    
}



class Solution {
public:
    double findMedianSortedArrays(int A[], int m, int B[], int n) {

        int k=(m+n+1)/2;
        
        if ((m+n)&1) {
            return find(A, m, B, n, k);
        } else {
            return (double)(find(A, m, B, n, k) + find(A, m, B, n, k+1))/2;
        }  
    }
};