题目
思路一 优先队列/多路归并
由于两个数组都是已排序的,所以先把第一个数组中的第一个元素与nums2中的所有元素的下标放入优先队列中。维护小根堆,每次弹出一个元素(i,j),并且把(i+1,j)放入优先队列中,直到结果数组的大小大于等于k。
代码一
class Solution {
public:
vector<vector<int>> kSmallestPairs(vector<int>& nums1, vector<int>& nums2, int k) {
int m = nums1.size(), n = nums2.size();
auto cmp = [&](vector<int> a, vector<int> b){
return nums1[a[0]] + nums2[a[1]] > nums1[b[0]] + nums2[b[1]];
};
priority_queue<vector<int>, vector<vector<int>>, decltype(cmp)> pq(cmp);
vector<vector<int>> ans;
for(int i = 0; i < min(n,k); i++)
pq.push({0,i});
while(ans.size() < k && !pq.empty()){
int i = pq.top()[0], j = pq.top()[1];
ans.push_back({nums1[i],nums2[j]});
pq.pop();
if(i + 1 < m) pq.push({i + 1,j});
}
return ans;
}
};
思路二 二分
数对和两端的值分别为l = nums1[0] + nums2[0]和r = nums1[m - 1] + nums2[n - 1],因此可以在值域[l,r]上进行二分,找到第一个满足点对和小于等于x且数量超过k的值x。判断是否满足数量超过k用循环实现。先遍历两个数组,把点对和小于x的值添加进答案。对于点对和等于x的值,可以通过枚举nums1[i],然后在nums2上二分目标值x-nums1[i]的左右端点。
代码二
class Solution {
public:
vector<int> nums1,nums2;
int m,n;
vector<vector<int>> kSmallestPairs(vector<int>& n1, vector<int>& n2, int k) {
nums1=n1;nums2=n2;
m=nums1.size();n=nums2.size();
vector<vector<int>> ans;
int l=nums1[0]+nums2[0],r=nums1[m-1]+nums2[n-1];
while(l<r){
int mid=l+r>>1;
if(check(mid,k)) r=mid;
else l=mid+1;
}
int x=r;
for(int a:nums1)
for(int b:nums2)
if(a+b<x)
ans.emplace_back(initializer_list<int>{a,b});
else
break;
for(int i=0;i<m && ans.size()<k;i++){
int a=nums1[i],b=x-nums1[i];
l=0;r=n-1;
//找到满足小于等于b的最右边界值
while(l<r){
int mid=l+r+1>>1; //如果l和r差1,那么mid和r相等
if(nums2[mid]<=b) l=mid; //所以这里修改l
else r=mid-1;
}
if(nums2[r]!=b) continue;
int end=r;
l=0;r=n-1;
//找到满足大于等于b的最左边界值
while(l<r){
int mid=l+r>>1; //如果l和r差1,那么mid与l相等
if(nums2[mid]>=b) r=mid; //所以这里修改r
else l=mid+1;
}
int start=r;
for(int p=start;p<=end && ans.size()<k;p++)
ans.emplace_back(initializer_list<int>{a,b});
}
return ans;
}
bool check(int sum,int k){
int ans=0;
for(int i=0;i<m && ans<k;i++){
int x = sum - nums1[i];
//找到nums2中满足小于等于x的右边界
int l = 0, r = n - 1;
while(l < r){
int mid = l + r + 1 >> 1;
if(nums2[mid] <= x) l = mid;
else r = mid - 1;
}
if(nums2[r] > x) continue;
ans += r + 1;
}
return ans>=k;
}
};