[六省联考2017]摧毁“树状图”
分析
题目大意:删去树上两条边不相交路径的剩余联通块个数最大值。
一道很烦很烦的树形Dp。
套路就是统计过根和不过根路径。
把路径看成线,那么子树合并就是线的拼接,我们称能拼接的线为一个线头。相当于是挂在子树根上的一条链。
f
f
f是不过子树内部不过根一条路径的答案。
g
g
g是不过子树内部不过根两条的答案。
h
h
h过子树根一个线头一条路径的答案。
h
1
h_1
h1过子树根两个线头一条路径的答案。
l
l
l过子树根一个或三个线头两条路径的答案。
l
1
l_1
l1过子树根两个或四个线头两条路径的答案。
用之前的信息转移一波即可。
看代码注释
#include<bits/stdc++.h>
const int N = 1e5 + 10;
int ri() {
char c = getchar(); int x = 0, f = 1; for(;c < '0' || c > '9'; c = getchar()) if(c == '-') f = -1;
for(;c >= '0' && c <= '9'; c = getchar()) x = (x << 1) + (x << 3) - '0' + c; return x * f;
}
int tp, to[N << 1], nx[N << 1], pr[N];
void add(int u, int v) {to[++tp] = v; nx[tp] = pr[u]; pr[u] = tp;}
void adds(int u, int v) {add(u, v); add(v, u);}
struct Data {
int f, g, h, h1, l, l1;
void clear() {f = g = h = h1 = l = l1 = 0;}
}t[N];
void Up(int &a, int b) {a = std::max(a, b);}
void Dp(int u, int fa) {
t[u].clear(); int d = 0, mx = 0;
for(int i = pr[u]; i; i = nx[i])
if(to[i] != fa){
Dp(to[i], u);
int f = t[u].f, g = t[u].g, h = t[u].h + 1, h1 = t[u].h1 + 1, l = t[u].l + 1, l1 = t[u].l1 + 1;
int t1 = std::max(t[to[i]].h, t[to[i]].h1);
//选择过当前子树根的一条链
Up(f, t[to[i]].f); //保存子树内不过子树根的答案
Up(f, t1 + 1); //用过子树的链更新 ,注意这样划分u会多出一个联通块
Up(g, t[to[i]].g); //保留子树内不过子树根的答案
Up(g, std::max(t[to[i]].l, t[to[i]].l1) + 1); //过子树的根的答案
Up(g, t[u].f + t1); //用之前子树的某条链+当前过子树根的答案
Up(h, t[to[i]].h + d); //子树伸出来
Up(l, t[to[i]].l + d); //同上
Up(h1, t[u].h + t[to[i]].h); //之前子树+当前子树伸出来
//l的一个线头
Up(l, t[u].h + t[to[i]].f); //之前子树的h+不过当前子树根的某条链
Up(l, t[u].h + t1); //之前子树的h+过当前子树根的某条链
Up(l, t[to[i]].h + mx); //当前子树的h+随便过之前子树的某条链
//l的三个线头
Up(l, t[to[i]].h + t[u].h1); //当前子树的h+过根的在之前子树的路径
Up(l1, t[u].h1 + t1); //之前子树的h1+当前子树过根的随便一条链
Up(l1, t[u].h + t[to[i]].l); //当前子树的l+之前子树的h伸出去接
Up(l1, t[u].l + t[to[i]].h); //之前子树的l+当前子树的h伸出去接
Up(l1, t[u].h1 + t[to[i]].f); //之前子树的h1+当前子树不过根随意一条链
mx = std::max(mx + 1, d + std::max(t1, t[to[i]].f)); //之前子树随便一条链的答案
++d;
t[u].f = f; t[u].g = g; t[u].h = h; t[u].l = l; t[u].h1 = h1; t[u].l1 = l1;
}
Up(t[u].h, d);
}
int main() {
for(int T = ri(), x = ri(); T--;) {
int n = ri(); tp = 0;
for(int i = 1;i <= n; ++i)
pr[i] = 0;
if(x) ri(), ri();
if(x == 2) ri(), ri();
for(int i = 1;i < n; ++i)
adds(ri(), ri());
Dp(1, 0);
printf("%d\n", std::max(t[1].g, std::max(t[1].l1, t[1].l)));
}
return 0;
}