Description:
题解:
第一问是经典的dp。
先随便选一个作为根。
设 fi 表示i已经被占,占领其子树需要的最少步数。
转移就把i的子节点的f值从大到小排序, fi=max(fson+numson)
之后考虑换根,没有什么区别,维护前缀max,后缀max就行了。
第二问的话考场时没有想到,太弱了。
把a-b的路径提出来, O(n2) 就是枚举在哪里断掉,分别dp取max。
那么贪心(直觉)告诉我们使两边的值越平均越好。
于是二分就出世了。
当然如果是把差取绝对值,就是三分了。
Samjia2000 dalao 有势能分析的一次dp是O(n)的做法,code 7000bytes+……鏼鏼发抖~
Code:
#include<cstdio>
#include<cstring>
#include<algorithm>
#define fo(i, x, y) for(int i = x; i <= y; i ++)
#define fd(i, x, y) for(int i = x; i >= y; i --)
#define max(a, b) ((a) > (b) ? (a) : (b))
#define min(a, b) ((a) < (b) ? (a) : (b))
using namespace std;
const int N = 3e5 + 5;
int n, a, b, x, y, rt;
int final[N], next[N * 2], to[N * 2], tot;
void link(int x, int y) {
next[++ tot] = final[x], to[tot] = y, final[x] = tot;
next[++ tot] = final[y], to[tot] = x, final[y] = tot;
}
int f[N], bz[N], d[N];
void dg(int x) {
bz[x] = 1;
for(int i = final[x]; i; i = next[i]) {
if(bz[to[i]]) continue; dg(to[i]);
}
d[0] = 0;
for(int i = final[x]; i; i = next[i]) {
if(bz[to[i]]) continue; d[++ d[0]] = f[to[i]];
}
sort(d + 1, d + d[0] + 1);
f[x] = -1e9;
fo(i, 1, d[0]) f[x] = max(f[x], d[i] + (d[0] - i + 1));
if(d[0] == 0) f[x] = 0;
bz[x] = 0;
}
bool rank(int x, int y) {return f[x] > f[y];}
int p[N], q[N], c[N], ans, la[N], ans2;
void dfs(int x) {
bz[x] = 1;
d[0] = 0;
for(int i = final[x]; i; i = next[i]) {
if(bz[to[i]]) continue; d[++ d[0]] = to[i];
}
f[0] = c[x]; if(x != rt) d[++ d[0]] = 0;
sort(d + 1, d + d[0] + 1, rank);
p[0] = q[d[0] + 1] = -1e9;
fo(i, 1, d[0]) p[i] = max(p[i - 1], f[d[i]] + i);
fd(i, d[0], 1) q[i] = max(q[i + 1], f[d[i]] + i - 1);
fo(i, 1, d[0]) if(d[i] != 0)
c[d[i]] = max(p[i - 1], q[i + 1]);
ans = min(ans, p[d[0]]);
for(int i = final[x]; i; i = next[i]) {
if(bz[to[i]]) continue; dfs(to[i]);
}
bz[x] = 0;
}
int dd[N];
int main() {
freopen("game.in", "r", stdin);
freopen("game.out", "w", stdout);
scanf("%d %d %d", &n, &a, &b);
fo(i, 1, n - 1) {
scanf("%d %d", &x, &y);
link(x, y);
}
ans = 1e9; rt = (n + 1) / 2;
dg(rt); dfs(rt);
printf("%d\n", ans);
d[d[0] = 1] = a;
fo(i, 1, d[0]) {
int x = d[i];
for(int j = final[x]; j; j = next[j]) {
if(!bz[to[j]]) bz[to[j]] = 1, la[to[j]] = x, d[++ d[0]] = to[j];
}
}
ans2 = 1e9;
memset(bz, 0, sizeof bz);
d[0] = 0; int x = b;
while(x != a) dd[++ dd[0]] = x, x = la[x];
dd[dd[0] + 1] = a;
for(int l = 1, r = dd[0]; l <= r; ) {
int m = l + r >> 1;
bz[dd[m]] = 1;
dg(a); ans = f[a];
int ans3 = ans; ans = 1e9;
bz[dd[m]] = 0; bz[dd[m + 1]] = 1;
dg(b); ans = f[b];
bz[dd[m + 1]] = 0;
ans2 = min(ans2, max(ans, ans3));
if(ans3 > ans) l = m + 1; else r = m - 1;
}
printf("%d\n", ans2);
}