算法-堆/多路归并-查找和最小的 K 对数字
1 题目概述
1.1 题目出处
https://leetcode.cn/problems/find-k-pairs-with-smallest-sums/description/?envType=study-plan-v2&envId=top-interview-150
1.2 题目描述
2 优先级队列构建大顶堆
2.1 思路
将两个数字的和放入大顶堆中,堆的最大大小为k。
当堆大小小于k时,直接放入堆。
当堆大小达到k后,比较当前元素和堆顶的元素,如果比堆顶元素小,就移除堆顶元素并放入当前元素。
最后,堆内元素就是和最小的K对数。
2.2 代码
class Solution {
List<List<Integer>> resultList = new LinkedList<>();
public List<List<Integer>> kSmallestPairs(int[] nums1, int[] nums2, int k) {
// 大顶堆
PriorityQueue<List<Integer>> queue = new PriorityQueue<>((o1, o2)->o2.get(0) + o2.get(1) - o1.get(0) - o1.get(1));
for (int i = 0; i < Math.min(nums1.length, k); i++) {
for (int j = 0; j < Math.min(nums2.length, k); j++) {
int sum = nums1[i] + nums2[j];
if (queue.size() < k) {
List<Integer> pair = new ArrayList<>();
pair.add(nums1[i]);
pair.add(nums2[j]);
queue.add(pair);
} else {
List<Integer> headPair = queue.peek();
int headSum = headPair.get(0) + headPair.get(1);
if (sum < headSum) {
queue.poll();
List<Integer> pair = new ArrayList<>();
pair.add(nums1[i]);
pair.add(nums2[j]);
queue.add(pair);
}
}
}
}
while (queue.size() > 0) {
resultList.add(queue.poll());
}
return resultList;
}
}
2.3 时间复杂度
O(Math.min(nums1.length, k) * Math.min(nums2.length, k))
悲催啊,超时了
2.4 空间复杂度
O(K)
3 优先级队列构建小顶堆
3.1 思路
- 构建一个小顶堆,堆顶元素存放的是和最小的两个数组的下标
- 初始时把0, 0 放入
- 不断从堆中取出堆顶元素,直到堆为空或者已经找齐k个最小对
- 每次取出堆顶元素,就代表是当前堆中和最小的对,把它们放入结果列表中。假设当前下标是nums1[m], nums2[n],那么下次放入的可能是nums1[m+1], nums2[n] 、nums1[m], nums2[n+1] ,都放入堆中
- 最后得到的就是结果
最大的好处就是能提前结束,不用把全部元素添加完成后再开始从堆中取元素。
3.2 代码
class Solution {
List<List<Integer>> resultList = new LinkedList<>();
public List<List<Integer>> kSmallestPairs(int[] nums1, int[] nums2, int k) {
// 小顶堆
PriorityQueue<List<Integer>> queue = new PriorityQueue<>(
(o1, o2) -> nums1[o1.get(0)] + nums2[o1.get(1)] - nums1[o2.get(0)] - nums2[o2.get(1)]);
Set<String> visitSet = new HashSet<>();
List<Integer> firstList = new ArrayList<>(2);
firstList.add(0);
firstList.add(0);
visitSet.add(0 + "#" + 0);
queue.add(firstList);
while (queue.size() > 0 && resultList.size() < k) {
List<Integer> indexList = queue.poll();
List<Integer> valueList = new ArrayList<>(2);
valueList.add(nums1[indexList.get(0)]);
valueList.add(nums2[indexList.get(1)]);
resultList.add(valueList);
if (resultList.size() == k) {
break;
}
int m = indexList.get(0);
int nextM = m + 1;
int n = indexList.get(1);
int nextN = n + 1;
if (nextM < nums1.length && (!visitSet.contains(nextM + "#" + n))) {
visitSet.add(nextM + "#" + n);
List<Integer> tmpList = new ArrayList<>(2);
tmpList.add(m + 1);
tmpList.add(n);
queue.add(tmpList);
}
if (nextN < nums2.length && (!visitSet.contains(m + "#" + nextN))) {
visitSet.add(m + "#" + nextN);
List<Integer> tmpList = new ArrayList<>(2);
tmpList.add(m);
tmpList.add(n + 1);
queue.add(tmpList);
}
}
return resultList;
}
}
3.3 时间复杂度
O(klog(k))
3.4 空间复杂度
O(K)
4 优先级队列构建小顶堆-优化
4.1 思路
3中需要建立一个HashSet,还需要拼接字符串来判定数字对是否已经放入堆,耗费了不少时间。其实转念想,如果我们固定一边,比如固定nums1的下标,取0到k-1,将他们和nums2的下标0的组合都放入小顶堆中。
然后,每次从堆中取出下标对,下次只需要将nums2的下标+1即可,而nums1的下标保持不变。
这么做的理由是,最多就找k个数字对,那么我们这样的做法肯定可以覆盖k个数字对。
4.2 代码
class Solution {
List<List<Integer>> resultList = new LinkedList<>();
public List<List<Integer>> kSmallestPairs(int[] nums1, int[] nums2, int k) {
// 小顶堆
PriorityQueue<List<Integer>> queue = new PriorityQueue<>(
(o1, o2) -> nums1[o1.get(0)] + nums2[o1.get(1)] - nums1[o2.get(0)] - nums2[o2.get(1)]);
Set<String> visitSet = new HashSet<>();
List<Integer> firstList = new ArrayList<>(2);
firstList.add(0);
firstList.add(0);
visitSet.add(0 + "#" + 0);
queue.add(firstList);
while (queue.size() > 0 && resultList.size() < k) {
List<Integer> indexList = queue.poll();
List<Integer> valueList = new ArrayList<>(2);
valueList.add(nums1[indexList.get(0)]);
valueList.add(nums2[indexList.get(1)]);
resultList.add(valueList);
if (resultList.size() == k) {
break;
}
int m = indexList.get(0);
int nextM = m + 1;
int n = indexList.get(1);
int nextN = n + 1;
if (nextM < nums1.length && (!visitSet.contains(nextM + "#" + n))) {
visitSet.add(nextM + "#" + n);
List<Integer> tmpList = new ArrayList<>(2);
tmpList.add(m + 1);
tmpList.add(n);
queue.add(tmpList);
}
if (nextN < nums2.length && (!visitSet.contains(m + "#" + nextN))) {
visitSet.add(m + "#" + nextN);
List<Integer> tmpList = new ArrayList<>(2);
tmpList.add(m);
tmpList.add(n + 1);
queue.add(tmpList);
}
}
return resultList;
}
}
4.3 时间复杂度
O(klog(k))
4.4 空间复杂度
O(K)