668. Kth Smallest Number in Multiplication Table
Nearly every one have used the Multiplication Table. But could you find out the k-th smallest number quickly from the multiplication table?
Given the height m and the length n of a m * n Multiplication Table, and a positive integer k, you need to return the k-th smallest number in this table.
Example1
Input: m = 3, n = 3, k = 5
Output: 3
Explanation:
The Multiplication Table:
1 2 3
2 4 6
3 6 9
The 5-th smallest number is 3 (1, 2, 2, 3, 3).
Example2
Input: m = 2, n = 3, k = 6
Output:
Explanation:
The Multiplication Table:
1 2 3
2 4 6
The 6-th smallest number is 6 (1, 2, 2, 3, 4, 6).
Note:
The m and n will be in the range [1, 30000].
The k will be in the range [1, m * n]
思路
题目要求我们在一个n * m
的乘法表里找第K小的数字,一开始我的想法是用BFS,因为乘法表有个特点,右边的数都比左边的大,下面的数要比上面的大,如果一开始把左上角的起点放进数组,再以某种顺序遍历其他节点的话,把他们按从小到大的顺序放进队列里,就能找出第K小的数了。
后来我发现BFS不太可行,虽然相邻节点之间的大小关系是确定,但当n个m足够大时,在乘法表右下角总能找到一个数,比已经放在BFS队列里的数值要小,这样的话就不能根据队列的长度判断当前数字是否是第K小了。我试着加入一些判断条件,改变节点的遍历顺序,然而这不仅使BFS变得繁琐,而且大多是错误的,求不出正解。
最后我放弃骚操作,决定老老实实地查找第K小的数,考虑到m、n的取值范围都比较大,一个一个地去试的话时间复杂度大概是O((m*n)^2)
,所以我才用了二分查找。对于每个数,都要扫描整个乘法表查找有多少个比它小的数,因此时间复杂度变为O((m*n)log(m*n))
。
我很快就写好了初版代码,毕竟也不是很长,除了二分查找的部分基本没啥麻烦的地方。提交上LeetCode发现Time Limit Exceeded
了。如果是超时的话,那大概就是扫描乘法表查找目前的数字是第几小的部分才会导致这个问题,这部分的代码我是这样写的:
int count = 0;
for(int i = 1; i <= n; i++) {
for (int j = 1; j <= m; j++){
if (i * j <= mid) count++;
else break;
}
}
对于每一行,我都从头开始往后找,当出现i*j>mid
时就可以结束这一行的遍历了,因为后面的数都比mid大。既然从头往后找会超时,我就从后往前找嘛,当K足够大时,理论上从后往前找的方法需要循环的次数的确会少一点。于是有了第二版代码:
int count = 0;
for(int i = 1; i <= n; i++) {
int j = 0;
for (j = m; j > 0; j--){
if (i * j <= mid) break;
}
count += j;
}
然而事实证明这并没有快多少,该超时的测试样例还是会超时,那么就要进一步剪枝,查找更少的节点来得到答案。这时我注意到了,一开始我思考BFS的解法时就发现右边的数总是比左边的大,下面的数总是比上面的大,这也就意味着,假如第i行的第j个数比mid大,那么第i+1行第一个比mid的数所在的列肯定小于等于j。最后经过这么多波折,我终于把这道看着简单的题目做出来了,它的难点在于对代码运行的时间要求太苛刻了,仅仅用二分查找是满足不了的,还要结合乘法表的特点进行剪枝。最终版的代码在下面给出。
最终代码
class Solution{
public:
int findKthNumber(int m, int n, int k) {
int head = 1;
int tail = m * n;
while(head < tail) {
int mid = head + (tail - head) / 2;
int count = 0;
int j = m;
for(int i = 1; i <= n; i++) {
for (j; j > 0; j--){
if (i * j <= mid) break;
}
count += j;
}
if(count < k) head = mid + 1;
else tail = mid;
}
return head;
}
};