【经典算法题】查找和最小的K对数字
Leetcode 0373 查找和最小的K对数字
分析
-
本题的考点:多路归并、堆。
-
本题一共有 n × m n \times m n×m对数据,将数组
nums1、nums2
分别记为a、b
,按照数组b
可以将这些数据分为m
路,且每一路都是升序的,如下图:
- 使用小根堆求解前
k
小的数据即可。
代码
- C++
typedef vector<int> VI; // (a[i] + b[j], i, j)
class Solution {
public:
vector<vector<int>> kSmallestPairs(vector<int> &a, vector<int> &b, int k) {
if (a.empty() || b.empty()) return {};
int n = a.size(), m = b.size();
priority_queue<VI, vector<VI>, greater<VI>> heap;
for (int i = 0; i < m; i++) heap.push({a[0] + b[i], 0, i});
vector<VI> res;
while (k-- && heap.size()) {
auto t = heap.top(); heap.pop();
res.push_back({a[t[1]], b[t[2]]});
if (t[1] + 1 < n)
heap.push({a[t[1] + 1] + b[t[2]], t[1] + 1, t[2]});
}
return res;
}
};
- Java
class Solution {
public List<List<Integer>> kSmallestPairs(int[] a, int[] b, int k) {
int n = a.length, m = b.length;
if (n == 0 || m == 0) return new ArrayList<>();
PriorityQueue<int[]> heap = new PriorityQueue<>((o1, o2) -> { return o1[0] - o2[0]; });
for (int i = 0; i < m; i++) heap.add(new int[]{a[0] + b[i], 0, i});
List<List<Integer>> res = new ArrayList<>();
while (k-- != 0 && !heap.isEmpty()) {
int[] t = heap.remove();
res.add(Arrays.asList(a[t[1]], b[t[2]]));
if (t[1] + 1 < n)
heap.add(new int[]{a[t[1] + 1] + b[t[2]], t[1] + 1, t[2]});
}
return res;
}
}
时空复杂度分析
-
时间复杂度: O ( k × l o g ( n ) ) O(k \times log(n)) O(k×log(n)),
n
为数组长度。 -
空间复杂度: O ( n ) O(n) O(n)。
扩展题目
AcWing 146. 序列
问题描述()
-
问题链接:AcWing 146. 序列
分析
-
一共
m
个序列,每个序列都是n
个数据,每个序列中选出一个数,求和。我们需要找到最小的n
个。 -
考虑前两个序列,两个序列可以构成 n 2 n^2 n2 个数据,我们找到前
n
个最小的数据即可,后面的数据一定用不到。因此对于这m
个序列,每次合并两个序列得到一个新的序列,直到合并到只有一个序列就是最后的答案。 -
考虑如何合并两个序列。对于数组
a
和数组b
,我们可以让数组a
升序排列,然后构造n
路,使用n
路归并排序即可。如下图:
- 本题的一个简化版本:Leetcode 0373 查找和最小的K对数字。
代码
- C++
#include <iostream>
#include <cstring>
#include <algorithm>
#include <queue>
#define x first
#define y second
using namespace std;
typedef pair<int, int> PII;
const int N = 2010;
int m, n;
int a[N], b[N];
int c[N]; // 多路归并过程中使用到的数组
void merge() {
priority_queue<PII, vector<PII>, greater<PII>> heap;
for (int i = 0; i < n; i++) heap.push({b[i] + a[0], 0});
for (int i = 0; i < n; i++) {
auto t = heap.top();
heap.pop();
c[i] = t.x;
heap.push({t.x - a[t.y] + a[t.y + 1], t.y + 1});
}
memcpy(a, c, sizeof a);
}
int main() {
int T;
scanf("%d", &T);
while (T--) {
scanf("%d%d", &m, &n);
for (int i = 0; i < n; i++) scanf("%d", &a[i]);
sort(a, a + n);
for (int i = 0; i < m - 1; i++) {
for (int j = 0; j < n; j++) scanf("%d", &b[j]);
merge();
}
for (int i = 0; i < n; i++) printf("%d ", a[i]);
puts("");
}
return 0;
}