题意
给定两个按非递减顺序排列的有序数组 A 和 B,长度分别为 m 和 n。
求这两个数组合并之后第 k 小的数(k 从 0 开始计数,即 k = 0 表示最小的数)。
解答
/*
 * find_kth - k-th smallest element of the union of two sorted arrays.
 * @a, @m: first array and its length, sorted in non-decreasing order.
 * @b, @n: second array and its length, sorted in non-decreasing order.
 * @k:     0-based rank; k == 0 selects the smallest element.
 *
 * Returns the element of rank k in the merged sequence. Out-of-range ranks
 * are clamped: k <= 0 yields the overall minimum and k >= m + n the overall
 * maximum. An array that is NULL or has a non-positive length is treated as
 * empty; when both are empty, INT_MAX is returned.
 *
 * Every loop iteration either returns or discards roughly half of one of the
 * two remaining ranges.
 */
int find_kth(const int *a, int m, const int *b, int n, int k){
    const int *afrom, *amid, *ato;
    const int *bfrom, *bmid, *bto;
    int aempty = (!a || m <= 0), bempty = (!b || n <= 0);
    int mid, kvalue;

    if (aempty && bempty) return INT_MAX;
    /* With one side empty the answer is an element of the other side.
     * Handling this up front avoids pointer arithmetic on a NULL pointer and
     * keeps a negative length from corrupting the total m + n. */
    if (aempty) return b[k <= 0 ? 0 : (k >= n ? n - 1 : k)];
    if (bempty) return a[k <= 0 ? 0 : (k >= m ? m - 1 : k)];
    if (k <= 0) return a[0] < b[0] ? a[0] : b[0];
    if (k >= m + n) return a[m - 1] > b[n - 1] ? a[m - 1] : b[n - 1];

    afrom = a; ato = a + m;
    bfrom = b; bto = b + n;
    /* From here on 0 <= k < (ato - afrom) + (bto - bfrom) holds at the top
     * of every iteration; each branch below preserves it. */
    while (afrom < ato && bfrom < bto) {
        amid = afrom + ((ato - afrom) >> 1);
        bmid = bfrom + ((bto - bfrom) >> 1);
        /* Number of elements in the two left halves, before the pivots. */
        mid = (int)((amid - afrom) + (bmid - bfrom));

        /* Boundary ranks are answered directly. */
        if (0 == k) return *afrom < *bfrom ? *afrom : *bfrom;
        if (k >= ((ato - afrom) + (bto - bfrom) - 1))
            return *(ato - 1) > *(bto - 1) ? *(ato - 1) : *(bto - 1);
        assert(0 < k && k < ((ato - afrom) + (bto - bfrom)));

        if (*amid == *bmid) {
            /* Both pivots stand apart from the halves around them:
             *   A: [afrom, amid), amid, [amid + 1, ato)
             *   B: [bfrom, bmid), bmid, [bmid + 1, bto)
             * The mid left-half elements come first, then the two equal
             * pivots at ranks mid and mid + 1, then the right halves. */
            if (k < mid) {
                ato = amid; bto = bmid;
            } else if (k > mid + 1) {
                /* Drop both left halves and both pivots: mid + 2 elements. */
                afrom = amid + 1; bfrom = bmid + 1;
                k -= mid + 2;
            } else {
                return *amid;
            }
        } else if (*amid > *bmid) {
            /*   A: [afrom, amid), [amid, ato)
             *   B: [bfrom, bmid), bmid, [bmid + 1, bto)
             * At least mid + 1 elements ([afrom, amid) and [bfrom, bmid])
             * are <= *amid, so a rank k < mid never falls in [amid, ato). */
            if (k < mid) {
                ato = amid;
            } else if (bmid == bfrom) {
                /* Dropping [bfrom, bmid) would remove nothing: B has a single
                 * element left. afrom[k] is in range because
                 * k < total - 1 == ato - afrom. */
                kvalue = afrom[k];
                if (kvalue <= *bmid) {
                    /* A's first k + 1 elements all precede *bmid. */
                    return kvalue;
                }
                /* *bmid lies within the first k + 1 elements; the answer is
                 * the larger of A's first k elements and *bmid, which the
                 * boundary check returns on the next iteration. */
                ato = afrom + k;
            } else {
                /* Every element of [bfrom, bmid) is <= *bmid < *amid, so it
                 * can only be preceded by [afrom, amid) and earlier B
                 * elements: its rank is below mid <= k. Drop that prefix. */
                k -= (int)(bmid - bfrom);
                bfrom = bmid;
            }
        } else {
            /* *amid < *bmid: mirror image of the previous branch. */
            if (k < mid) {
                bto = bmid;
            } else if (amid == afrom) {
                kvalue = bfrom[k];
                if (kvalue <= *amid) {
                    return kvalue;
                }
                bto = bfrom + k;
            } else {
                k -= (int)(amid - afrom);
                afrom = amid;
            }
        }
    }
    /* One range is exhausted; the answer sits at offset k in the other. */
    if (bfrom >= bto) return afrom[k];
    return bfrom[k];
}
下面是测试代码:test() 利用随机数生成各种长度的有序数组,并把 find_kth 的结果与合并排序后的参考结果逐一比对。
测试代码:
#include <limits.h>
#include <stdlib.h>
#include <stdio.h>
#include <assert.h>
/*
 * find_kth - k-th smallest element of the union of two sorted arrays.
 * @a, @m: first array and its length, sorted in non-decreasing order.
 * @b, @n: second array and its length, sorted in non-decreasing order.
 * @k:     0-based rank; k == 0 selects the smallest element.
 *
 * Returns the element of rank k in the merged sequence. Out-of-range ranks
 * are clamped: k <= 0 yields the overall minimum and k >= m + n the overall
 * maximum. An array that is NULL or has a non-positive length is treated as
 * empty; when both are empty, INT_MAX is returned.
 *
 * Every loop iteration either returns or discards roughly half of one of the
 * two remaining ranges.
 */
int find_kth(const int *a, int m, const int *b, int n, int k){
    const int *afrom, *amid, *ato;
    const int *bfrom, *bmid, *bto;
    int aempty = (!a || m <= 0), bempty = (!b || n <= 0);
    int mid, kvalue;

    if (aempty && bempty) return INT_MAX;
    /* With one side empty the answer is an element of the other side.
     * Handling this up front avoids pointer arithmetic on a NULL pointer and
     * keeps a negative length from corrupting the total m + n. */
    if (aempty) return b[k <= 0 ? 0 : (k >= n ? n - 1 : k)];
    if (bempty) return a[k <= 0 ? 0 : (k >= m ? m - 1 : k)];
    if (k <= 0) return a[0] < b[0] ? a[0] : b[0];
    if (k >= m + n) return a[m - 1] > b[n - 1] ? a[m - 1] : b[n - 1];

    afrom = a; ato = a + m;
    bfrom = b; bto = b + n;
    /* From here on 0 <= k < (ato - afrom) + (bto - bfrom) holds at the top
     * of every iteration; each branch below preserves it. */
    while (afrom < ato && bfrom < bto) {
        amid = afrom + ((ato - afrom) >> 1);
        bmid = bfrom + ((bto - bfrom) >> 1);
        /* Number of elements in the two left halves, before the pivots. */
        mid = (int)((amid - afrom) + (bmid - bfrom));

        /* Boundary ranks are answered directly. */
        if (0 == k) return *afrom < *bfrom ? *afrom : *bfrom;
        if (k >= ((ato - afrom) + (bto - bfrom) - 1))
            return *(ato - 1) > *(bto - 1) ? *(ato - 1) : *(bto - 1);
        assert(0 < k && k < ((ato - afrom) + (bto - bfrom)));

        if (*amid == *bmid) {
            /* Both pivots stand apart from the halves around them:
             *   A: [afrom, amid), amid, [amid + 1, ato)
             *   B: [bfrom, bmid), bmid, [bmid + 1, bto)
             * The mid left-half elements come first, then the two equal
             * pivots at ranks mid and mid + 1, then the right halves. */
            if (k < mid) {
                /* The target lies in the left halves: drop [amid, ato) and
                 * [bmid, bto). */
                ato = amid; bto = bmid;
            } else if (k > mid + 1) {
                /* The target lies in the right halves: drop [afrom, amid]
                 * and [bfrom, bmid], i.e. mid + 2 elements. */
                afrom = amid + 1; bfrom = bmid + 1;
                k -= mid + 2;
            } else {
                /* k == mid or k == mid + 1: the target is the pivot value. */
                return *amid;
            }
        } else if (*amid > *bmid) {
            /*   A: [afrom, amid), [amid, ato)
             *   B: [bfrom, bmid), bmid, [bmid + 1, bto)
             * At least mid + 1 elements ([afrom, amid) and [bfrom, bmid])
             * are <= *amid, so a rank k < mid never falls in [amid, ato). */
            if (k < mid) {
                ato = amid;
            } else if (bmid == bfrom) {
                /* Dropping [bfrom, bmid) would remove nothing: B has a single
                 * element left while A may still be long. afrom[k] is in
                 * range because k < total - 1 == ato - afrom. */
                kvalue = afrom[k];
                if (kvalue <= *bmid) {
                    /* A's first k + 1 elements all precede *bmid. */
                    return kvalue;
                }
                /* *bmid lies within the first k + 1 elements; the answer is
                 * the larger of A's first k elements and *bmid, which the
                 * boundary check returns on the next iteration. */
                ato = afrom + k;
            } else {
                /* Every element of [bfrom, bmid) is <= *bmid < *amid, so it
                 * can only be preceded by [afrom, amid) and earlier B
                 * elements: its rank is below mid <= k. Drop that prefix. */
                k -= (int)(bmid - bfrom);
                bfrom = bmid;
            }
        } else {
            /* *amid < *bmid: mirror image of the previous branch. */
            if (k < mid) {
                bto = bmid;
            } else if (amid == afrom) {
                kvalue = bfrom[k];
                if (kvalue <= *amid) {
                    return kvalue;
                }
                bto = bfrom + k;
            } else {
                k -= (int)(amid - afrom);
                afrom = amid;
            }
        }
    }
    /* One range is exhausted; the answer sits at offset k in the other.
     * k is already within range thanks to the loop invariant. */
    if (bfrom >= bto) return afrom[k];
    return bfrom[k];
}
/*
 * qsort comparator for int: returns a negative, zero, or positive value as
 * *a is less than, equal to, or greater than *b. Comparisons are used instead
 * of subtraction, which overflows (undefined behavior) when the operands have
 * opposite signs and large magnitudes.
 */
int cmp(const void *a, const void *b) {
    int x = *(const int *)a;
    int y = *(const int *)b;
    return (x > y) - (x < y);
}
/* Print the first n values of a on one line, each followed by ", ". The
 * array is only read, so it is taken as a pointer to const. */
void aprint(const int *a, int n) {
    int i;
    for (i = 0; i < n; ++i) {
        printf("%d, ", a[i]);
    }
    printf("\n");
}
/*
 * Run one randomized check of find_kth().
 *
 * Fills a[0..m) with rand() % 100 values. b[0..n) is filled the same way, or,
 * when mirror is nonzero, is a copy of a; callers must then pass n == m.
 * Both arrays are sorted. find_kth() is then compared for every k in
 * [-2, m + n + 10) against the merged, sorted reference array c. Ranks outside
 * [0, m + n) are expected to clamp to the first or last reference element.
 * Two empty inputs are expected to yield INT_MAX, matching find_kth()'s
 * contract.
 *
 * Returns 1 when every k matches. Returns 0 on the first mismatch, after
 * printing the offending inputs, or on allocation failure.
 */
static int check_case(int m, int n, int mirror) {
    int total = m + n;
    /* Allocate at least one element so that a legal NULL from malloc(0) is
     * not mistaken for an allocation failure. */
    int *a = malloc(sizeof *a * (size_t)(m > 0 ? m : 1));
    int *b = malloc(sizeof *b * (size_t)(n > 0 ? n : 1));
    int *c = malloc(sizeof *c * (size_t)(total > 0 ? total : 1));
    int i, ret, cret, ok = 1;

    if (!a || !b || !c) {
        fprintf(stderr, "test: out of memory (m = %d, n = %d)\n", m, n);
        ok = 0;
        goto out;
    }
    for (i = 0; i < m; ++i) { a[i] = rand() % 100; c[i] = a[i]; }
    for (i = 0; i < n; ++i) { b[i] = mirror ? a[i] : rand() % 100; c[m + i] = b[i]; }
    qsort(a, (size_t)m, sizeof *a, cmp);
    qsort(b, (size_t)n, sizeof *b, cmp);
    qsort(c, (size_t)total, sizeof *c, cmp);
    for (i = -2; i < total + 10; ++i) {
        /* Both arrays empty: there is no c[0] to read. */
        if (total == 0) cret = INT_MAX;
        else if (i <= 0) cret = c[0];
        else if (i >= total) cret = c[total - 1];
        else cret = c[i];
        ret = find_kth(a, m, b, n, i);
        if (ret != cret) {
            printf("Error i = %d, ret = %d, cret = %d\n", i, ret, cret);
            printf("a = "); aprint(a, m);
            printf("b = "); aprint(b, n);
            printf("c = "); aprint(c, total);
            /* Repeat the failing call so a debugger breakpoint can step in. */
            (void)find_kth(a, m, b, n, i);
            ok = 0;
            break;
        }
    }
out:
    free(a); free(b); free(c);
    return ok;
}

/*
 * Randomized self-test of find_kth(). It runs 100 iterations of each input
 * shape below and stops at the first failing case.
 */
void test(void) {
    int iter, m, n;
    int ok = 1;

    /* Random lengths for both arrays (either may be 0). */
    for (iter = 0; iter < 100 && ok; ++iter) {
        m = rand() % 1000; n = rand() % 1000;
        ok = check_case(m, n, 0);
    }
    /* A single element against a random-length array. */
    for (iter = 0; iter < 100 && ok; ++iter) {
        n = rand() % 1000;
        ok = check_case(1, n, 0);
    }
    /* One element on each side. */
    for (iter = 0; iter < 100 && ok; ++iter) {
        ok = check_case(1, 1, 0);
    }
    /* A random-length array against a single element. */
    for (iter = 0; iter < 100 && ok; ++iter) {
        m = rand() % 1000;
        ok = check_case(m, 1, 0);
    }
    /* Two identical arrays: stresses the equal-pivot branch. */
    for (iter = 0; iter < 100 && ok; ++iter) {
        m = rand() % 1000;
        ok = check_case(m, m, 1);
    }
}
/*
 * Run the randomized self-test, then answer interactive queries. Each
 * integer k read from stdin is answered with find_kth() over two copies of
 * {1, 2, 3, 4, 5, 6}. Reading stops at end of input or at the first token
 * that is not an integer.
 */
int main(void) {
    int a[] = {1, 2, 3, 4, 5, 6};
    int m = (int)(sizeof(a) / sizeof(a[0]));
    const int *b = a;
    int n = m;
    int x;

    test();
    /* scanf returns 0, not EOF, on a non-numeric token and leaves it unread.
     * Testing "!= EOF" would then loop forever, so require one conversion. */
    while (scanf("%d", &x) == 1) {
        printf("%d\n", find_kth(a, m, b, n, x));
    }
    return 0;
}