题目链接: 三元组最小距离
定义三元组 $(a, b, c)$($a,b,c$ 均为整数)的距离 $D=|a-b|+|b-c|+|c-a|$。
给定 $3$ 个非空整数集合 $S_1, S_2, S_3$,按升序分别存储在 $3$ 个数组中。
请设计一个尽可能高效的算法,计算并输出所有可能的三元组 $(a, b, c)$($a \in S_1,b \in S_2,c \in S_3$)中的最小距离。
例如 $S_1=\{-1, 0, 9\}, S_2=\{-25, -10, 10, 11\}, S_3=\{2, 9, 17, 30, 41\}$ 则最小距离为 $2$,相应的三元组为 $(9,10,9)$。
输入格式
第一行包含三个整数 $l,m,n$,分别表示 $S_1,S_2,S_3$ 的长度。
第二行包含 l 个整数,表示 $S_1$ 中的所有元素。
第三行包含 $m$ 个整数,表示 $S_2$ 中的所有元素。
第四行包含 $n$ 个整数,表示 $S_3$ 中的所有元素。
以上三个数组中的元素都是按升序顺序给出的。
输出格式
输出三元组的最小距离。
数据范围
$1 \le l,m,n \le 10^5$,
所有数组元素的取值范围 $[-10^9,10^9]$。
输入样例:
3 4 5
-1 0 9
-25 -10 10 11
2 9 17 30 41
输出样例:
2
-
暴力想法: 枚举所有可能的答案排列 从小到大 [S1 S2 S3], [S3 S2 S1], [S2, S3, S1], [S1, S3, S2], [S2, S1, S3], [S3, S1, S2], 如果单纯暴力复杂度就是O(6n^3) 铁定过不了,这时候我们可以选择枚举中间值属于哪个素组, 在每个枚举中我们已经确定了中间的数,那么我们就可以根据这个中间值在另外两个数组中二分找到符合排列的数,这里又有两个前后问题,我们还是直接枚举比较,比如确定中间的数来自于S2,那么答案排列可能是S1,S2,S3, 或者 S3, S2, S1这样算下来总的时间复杂度O(3nlogn).
-
代码
#include<bits/stdc++.h>
using namespace std;
const int N = 1e5 + 10;
typedef long long LL;
int main()
{
int l, n, m; cin >> l >> n >> m;
vector<int> a(l), b(n), c(m);
for(int i = 0; i < l; i ++) cin >> a[i];
for(int i = 0; i < n; i ++) cin >> b[i];
for(int i = 0; i < m; i ++) cin >> c[i];
// a中元素当中间值
LL ans = 1e12;
for(auto it : a)
{
//b a c
int dexb = upper_bound(b.begin(), b.end(), it) - b.begin();
if(dexb != 0) dexb --;
int dexc = lower_bound(c.begin(), c.end(), it) - c.begin();
if(dexc == c.size()) dexc --;
int xa = it, xb = b[dexb], xc = c[dexc];
ans = min(ans, 1ll*abs(xa - xb) + abs(xa - xc) + abs(xb - xc));
//c a b
dexc = upper_bound(c.begin(), c.end(), it) - c.begin();
if(dexc != 0) dexc --;
dexb = lower_bound(b.begin(), b.end(), it) - b.begin();
if(dexb == b.size()) dexb --;
xa = it, xb = b[dexb], xc = c[dexc];
ans = min(ans, 1ll*abs(xa - xb) + abs(xa - xc) + abs(xb - xc));
}
// b 当中间值
for(auto it : b)
{
//a b c
int dexa = upper_bound(a.begin(), a.end(), it) - a.begin();
if(dexa != 0) dexa --;
int dexc = lower_bound(c.begin(), c.end(), it) - c.begin();
if(dexc == c.size()) dexc --;
int xb = it, xa = a[dexa], xc = c[dexc];
ans = min(ans, 1ll*abs(xa - xb) + abs(xa - xc) + abs(xb - xc));
//c b a
dexc = upper_bound(c.begin(), c.end(), it) - c.begin();
if(dexc != 0) dexc --;
dexa = lower_bound(a.begin(), a.end(), it) - a.begin();
if(dexa == a.size()) dexa --;
xb = it, xa = a[dexa], xc = c[dexc];
ans = min(ans, 1ll*abs(xa - xb) + abs(xa - xc) + abs(xb - xc));
}
// c 当中间值
for(auto it : c)
{
//a c b
int dexa = upper_bound(a.begin(), a.end(), it) - a.begin();
if(dexa != 0) dexa --;
int dexb = lower_bound(b.begin(), b.end(), it) - b.begin();
if(dexb == b.size()) dexb --;
int xc = it, xa = a[dexa], xb = b[dexb];
ans = min(ans, 1ll*abs(xa - xb) + abs(xa - xc) + abs(xb - xc));
//b c a
dexb = upper_bound(b.begin(), b.end(), it) - b.begin();
if(dexb != 0) dexb --;
dexa = lower_bound(a.begin(), a.end(), it) - a.begin();
if(dexa == a.size()) dexa --;
xc = it, xa = a[dexa], xb = b[dexb];
ans = min(ans, 1ll*abs(xa - xb) + abs(xa - xc) + abs(xb - xc));
}
cout << ans << endl;
return 0;
}
- 另一个思路:滑动窗口 O(3n*log(3n))实现
我们可以标记好每个数来自哪个数组然后统一排序,滑动窗口找相邻三个不同归属的数进行答案比较,这个和上面的相比实现更简单些! - 代码
#include<bits/stdc++.h>
using namespace std;
const int N = 1e5 + 10;
typedef long long LL;
int main()
{
int l, n, m; cin >> l >> n >> m;
// 可以结构体数组,或者pair 存数值与所属关系,用multimap纯属个人偷懒行为
multimap<int, int> cnt;
for(int i = 0; i < l; i ++)
{
int x; cin >> x;
cnt.insert(pair<int, int>(x, 1));
}
for(int i = 0; i < n; i ++)
{
int x; cin >> x;
cnt.insert(pair<int, int>(x, 2));
}
for(int i = 0; i < m; i ++)
{
int x; cin >> x;
cnt.insert(pair<int, int>(x, 3));
}
LL ans = 1e18;
vector<LL> st(4, 1e10);// 将st1,2,3赋值1e10表示空
for(auto [a,b] : cnt)
{
st[b] = a;
if(st[1] != 1e10 && st[2] != 1e10 && st[3] != 1e10)
{
ans = min(ans, 1ll*abs(st[1]-st[2]) + abs(st[1]-st[3]) + abs(st[2]-st[3]));
}
}
cout << ans << endl;
return 0;
}
- 进阶思路:O(3n)实现 三路归并
假设x < y < z 我们化简 |x-y|+|y-z|+|z-x| 后发现 我们每一次算出的答案都是2*(max-min)
所以我们只需要每个三元组中的最大值与最小值即可,所以我们尽可能让max与min逼近
这时候就可以三路归并,每次只让最小的去靠近最大的值,实现也很简单。 - 代码
#include<bits/stdc++.h>
using namespace std;
const int N = 1e5 + 10;
typedef long long LL;
int main()
{
int l, n, m; cin >> l >> n >> m;
vector<int> a(l), b(n), c(m);
for(int i = 0; i < l; i ++) cin >> a[i];
for(int i = 0; i < n; i ++) cin >> b[i];
for(int i = 0; i < m; i ++) cin >> c[i];
LL ans = 1e18;
for(int i = 0, j = 0, k = 0; i < l && j < n && k < m;)
{
int x = a[i], y = b[j], z = c[k];
ans = min(ans, 2*(1ll*max(x, max(y, z)) - min(x, min(y, z))));
if(x <= y && x <= z) i ++;
else if(y <= x && y <= z) j ++;
else k ++;
}
cout << ans << endl;
return 0;
}