Description
Input
Output
Sample Input
输入1:
6 2 1
1 2
2 3
2 4
1 5
5 6
输入2:
10 1 2
1 2
2 5
1 3
1 4
4 6
6 7
3 8
3 9
3 10
输入3:
17 1 2
1 3
1 4
4 6
6 7
3 8
3 9
3 10
1 13
13 5
13 11
13 12
13 14
14 15
15 16
15 17
14 2
Sample Output
输出1:
3
2
输出2:
4
4
输出3:
5
5
Data Constraint
Hint
Solution
假定以 a 为根节点,一个比较容易想的树形DP。
设
f[i] 表示到达 i 号点还要多少步才能走完以i 为根的子树。那么 f[i] 的值可以通过 i 的儿子节点转移过来。
设
i 的儿子节点按其 f 值从大到小排序后为son1−k :f[i]=maxj=1k(f[sonj]+j)因为走到分叉口,一定是先走需要步数多的子树。
之后我们换根到每个点(换根只需改两个点的 f 值,暴力计算即可),取最优答案即可。
那如果是固定以
a,b 为根的第二问该怎么处理呢?其实与第一问本质是一样的,考虑在路径 a,b 上找一条边断掉。
那就形成了两棵分别以 a,b 为根的树,用同样的方法求出 f[a] 和 f[b] 。
显然是步数越平均越好,所以我们可以二分断在哪里。
如果 f[a]<f[b] ,就把断点往 b 点移,否则往
a 点移。二分过程中更新答案即可, ans=min(ans,max(f[a],f[b])) 。
时间复杂度谜之 O(N log N) 。
Code
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cctype>
#include<vector>
using namespace std;
typedef pair<int,int> PI;
const int N=3e5+5;
int n,a,b,tot,ans=1e9;
int f[N],g[N],pre[N],suf[N],fa[N],q[N];
vector<int>e[N];
inline int read()
{
int X=0,w=0; char ch=0;
while(!isdigit(ch)) w|=ch=='-',ch=getchar();
while(isdigit(ch)) X=(X<<3)+(X<<1)+(ch^48),ch=getchar();
return w?-X:X;
}
inline int min(int x,int y)
{
return x<y?x:y;
}
inline int max(int x,int y)
{
return x>y?x:y;
}
inline void insert(int x,int y)
{
e[x].push_back(y);
e[y].push_back(x);
}
inline void erase(int x,int y)
{
e[x].erase(find(e[x].begin(),e[x].end(),y));
e[y].erase(find(e[y].begin(),e[y].end(),x));
}
int dfs(int x)
{
vector<int>h;
for(int i=0;i<e[x].size();i++)
if(e[x][i]^fa[x])
{
fa[e[x][i]]=x;
h.push_back(dfs(e[x][i]));
}
sort(h.begin(),h.end(),greater<int>());
int mx=0;
for(int i=0;i<h.size();i++) mx=max(mx,h[i]+i+1);
return f[x]=mx;
}
void change(int x)
{
vector<PI>t;
for(int i=0;i<e[x].size();i++)
if(e[x][i]^fa[x]) t.push_back(make_pair(f[e[x][i]],e[x][i]));
if(fa[x]) t.push_back(make_pair(g[x],x));
sort(t.begin(),t.end(),greater<PI>());
int m=t.size();
pre[0]=t[0].first+1,suf[m]=0;
for(int i=1;i<m;i++) pre[i]=max(pre[i-1],t[i].first+1+i);
for(int i=m-1;i>=0;i--) suf[i]=max(suf[i+1],t[i].first+1+i);
ans=min(ans,pre[m-1]);
for(int i=0;i<m;i++)
if(t[i].second^x) g[t[i].second]=max(i?pre[i-1]:0,suf[i+1]-1);
for(int i=0;i<e[x].size();i++)
if(e[x][i]^fa[x]) change(e[x][i]);
}
int main()
{
n=read(),a=read(),b=read();
for(int i=1;i<n;i++) insert(read(),read());
dfs(a),change(a);
printf("%d\n",ans);
for(int i=b;i;i=fa[i]) q[++tot]=i;
reverse(q+1,q+1+tot);
int l=2,r=tot;
fa[b]=0,ans=1e9;
while(l<=r)
{
int mid=l+r>>1;
erase(q[mid-1],q[mid]);
int x=dfs(a),y=dfs(b);
insert(q[mid-1],q[mid]);
ans=min(ans,max(x,y));
if(x<y) l=mid+1; else r=mid-1;
}
printf("%d\n",ans);
return 0;
}