次小生成树 / 严格次小生成树
题目链接:ybt高效进阶4-5-6 / LOJ 10133 / luogu P4180
题目大意
给你一个图,要你求这个图的次小生成树。
次小生成树就是边权之和大于最小生成树边权和的生成树的边权之和最小的一个。
一定要严格大于,保证一定存在,输出这个边权和。
思路
首先看到次小生成树,我们考虑先把最小生成树弄出来。
那弄出了最小生成树,次小生成树自然就是选一个边换。
那怎么弄呢?
直接枚举判断连通会超时,枚举在最小生成树上的似乎找不到方法加速。
那枚举不在最小生成树上的呢?
那你首先会想到有哪些最小生成树上的边是可以换的。
那就是那两个点在最小生成树上的链,可以通过 LCA 将它分成两条。
然后你会想到肯定是找着两条链上权值最大的一条换。
但首先它是有一点问题的,就是它是要次小生成树大小严格小于最小生成树的,那如果你找到权值最大的和那条不在最小生成树上的点一样大的时候,就不能选这条边。(不会大于,想想求最小生成树的过程就知道)
那就要选严格小于这条边的边,那也就是说我们要找的是着两条链上最大和小于最大的边中最大的(也就是严格第二大的)的边。
那你可以通过倍增 DP 来处理,然后再求 LCA 往上跳的时候更新即可。
有一点要注意的是,你在往上跳更新的时候,不能排序直接选第二个。
因为你要的第二是严格小于第一的,那如果前两个一样,那就要选第三个;前三个都一样,就选第四个;如果四个都一样,那就没有第二。
然后这样做就可以了。
代码
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ll long long
using namespace std;
struct road {
int x, y, dis;
bool cho;
}a[300001];
struct node {
int x, to, nxt;
}e[600001];
int n, m, le[100001], KK, nn;
int f[100001], fa[21][100001], maxn[21][100001], maxn_sec[21][100001];
int deg[100001], tmpmaxn, tmpmaxn_sec;
ll ans, minn_ans = 1e15;
bool cmp(road x, road y) {
return x.dis < y.dis;
}
int find(int now) {
if (f[now] == now) return now;
return f[now] = find(f[now]);
}
void add(int x, int y, int z) {
e[++KK] = (node){z, y, le[x]}; le[x] = KK;
e[++KK] = (node){z, x, le[y]}; le[y] = KK;
}
void dfs(int now, int father) {
deg[now] = deg[father] + 1;
fa[0][now] = father;
for (int i = le[now]; i; i = e[i].nxt)
if (e[i].to != father)
dfs(e[i].to, now), maxn[0][e[i].to] = e[i].x;
}
int LCA(int x, int y) {//LCA(在跳的过程中记录边权最大和第二大的)
if (deg[y] > deg[x]) swap(x, y);
for (int i = 20; i >= 0; i--)
if (deg[fa[i][x]] >= deg[y]) {
//更新
if (maxn[i][x] > tmpmaxn) tmpmaxn = maxn[i][x];
else if (maxn[i][x] < tmpmaxn && maxn[i][x] > tmpmaxn_sec) tmpmaxn_sec = maxn[i][x];
if (maxn_sec[i][x] > tmpmaxn) tmpmaxn = maxn_sec[i][x];
else if (maxn_sec[i][x] < tmpmaxn && maxn_sec[i][x] > tmpmaxn_sec) tmpmaxn_sec = maxn_sec[i][x];
x = fa[i][x];
}
if (x == y) return x;
for (int i = 20; i >= 0; i--)
if (fa[i][x] != fa[i][y]) {
//更新
if (maxn[i][x] > tmpmaxn) tmpmaxn = maxn[i][x];
else if (maxn[i][x] < tmpmaxn && maxn[i][x] > tmpmaxn_sec) tmpmaxn_sec = maxn[i][x];
if (maxn_sec[i][x] > tmpmaxn) tmpmaxn = maxn_sec[i][x];
else if (maxn_sec[i][x] < tmpmaxn && maxn_sec[i][x] > tmpmaxn_sec) tmpmaxn_sec = maxn_sec[i][x];
if (maxn[i][y] > tmpmaxn) tmpmaxn = maxn[i][y];
else if (maxn[i][y] < tmpmaxn && maxn[i][y] > tmpmaxn_sec) tmpmaxn_sec = maxn[i][y];
if (maxn_sec[i][y] > tmpmaxn) tmpmaxn = maxn_sec[i][y];
else if (maxn_sec[i][y] < tmpmaxn && maxn_sec[i][y] > tmpmaxn_sec) tmpmaxn_sec = maxn_sec[i][y];
x = fa[i][x];
y = fa[i][y];
}
//记得更新最后的两个小跳
if (maxn[0][x] > tmpmaxn) tmpmaxn = maxn[0][x];
else if (maxn[0][x] < tmpmaxn && maxn[0][x] > tmpmaxn_sec) tmpmaxn_sec = maxn[0][x];
if (maxn_sec[0][x] > tmpmaxn) tmpmaxn = maxn_sec[0][x];
else if (maxn_sec[0][x] < tmpmaxn && maxn_sec[0][x] > tmpmaxn_sec) tmpmaxn_sec = maxn_sec[0][x];
if (maxn[0][y] > tmpmaxn) tmpmaxn = maxn[0][y];
else if (maxn[0][y] < tmpmaxn && maxn[0][y] > tmpmaxn_sec) tmpmaxn_sec = maxn[0][y];
if (maxn_sec[0][y] > tmpmaxn) tmpmaxn = maxn_sec[0][y];
else if (maxn_sec[0][y] < tmpmaxn && maxn_sec[0][y] > tmpmaxn_sec) tmpmaxn_sec = maxn_sec[0][y];
return fa[0][x];
}
int main() {
scanf("%d %d", &n, &m);
for (int i = 1; i <= m; i++)
scanf("%d %d %d", &a[i].x, &a[i].y, &a[i].dis);
sort(a + 1, a + m + 1, cmp);
for (int i = 1; i <= n; i++) f[i] = i;
for (int i = 1; i <= m; i++) {
int X = find(a[i].x), Y = find(a[i].y);
if (X == Y) continue;
nn++;
f[X] = Y;
add(a[i].x, a[i].y, a[i].dis);
a[i].cho = 1;
ans += 1ll * a[i].dis;
if (nn == n - 1) break;
}
memset(maxn, -1, sizeof(maxn));
memset(maxn_sec, -1, sizeof(maxn_sec));
dfs(1, 0);
for (int i = 1; i <= 20; i++)
for (int j = 1; j <= n; j++) {
//就是从两个区间的最大和第二大里面(一共四个数)找第二大的
maxn_sec[i][j] = min(maxn[i - 1][j], maxn[i - 1][fa[i - 1][j]]);
if (maxn[i - 1][j] == maxn[i - 1][fa[i - 1][j]]) maxn_sec[i][j] = min(maxn_sec[i - 1][j], maxn_sec[i][j]);
else maxn_sec[i][j] = max(maxn_sec[i - 1][j], maxn_sec[i][j]);
if (maxn[i - 1][j] == maxn[i - 1][fa[i - 1][j]] && maxn[i - 1][fa[i - 1][j]] == maxn_sec[i - 1][j]) maxn_sec[i][j] = max(maxn_sec[i][j], maxn_sec[i - 1][fa[i - 1][j]]);
else maxn_sec[i][j] = max(maxn_sec[i][j], maxn_sec[i - 1][fa[i - 1][j]]);
//最大直接
maxn[i][j] = max(maxn[i - 1][j], maxn[i - 1][fa[i - 1][j]]);
fa[i][j] = fa[i - 1][fa[i - 1][j]];
}
for (int i = 1; i <= m; i++)
if (!a[i].cho) {//处理每条不是最小生成树上的边
tmpmaxn = tmpmaxn_sec = -1;
int x = a[i].x, y = a[i].y;
int lca = LCA(x, y);
if (a[i].dis != tmpmaxn) minn_ans = min(minn_ans, 1ll * (a[i].dis - tmpmaxn));//如果更改了边权最大的边值不一样就改这条
else if (tmpmaxn_sec != -1) minn_ans = min(minn_ans, 1ll * (a[i].dis - tmpmaxn_sec));//否则就改边权第二大的(因为我们这里处理的是严格,所以一定会不一样)
}
printf("%lld", ans + minn_ans);
return 0;
}