题目
题目大意
第一问:找一个点,从它向外扩散,最少的扩散次数。
第二问:给定两个点,求从这两个点向外扩散的最少次数。
题解
对于第一问,
很自然的想到一种
O(n2)
O
(
n
2
)
的dp
设
fi
f
i
表示从第i个点开始扩散,使整个子树的节点都被占领的最少次数。
转移的时候,就将所有的儿子节点从大到小排序,先向大的扩散。
考虑换根,
如果是暴力换根,每次重新排序,这样显然超时。
设
upi
u
p
i
表示i它的父亲连出去所有边中除去i的f值,
每次利用
upi
u
p
i
和i的儿子的f求出i儿子的up。
对于第二问,
既然两个点都是固定的,
将这两个点的连线来出来,
这条路径上必然有一个分界点,上面被一个点占领,下面被另外一个点占领,
二分这个分界点,分别对两个部分dp一下,就可以了。
code
#include<queue>
#include<cstdio>
#include<iostream>
#include<algorithm>
#include <cstring>
#include <string.h>
#include <cmath>
#include <math.h>
#include <time.h>
#include <vector>
#define ll long long
#define N 300003
#define M 103
#define db double
#define P putchar
#define G getchar
#define inf 998244353
#define fi first
#define se second
using namespace std;
char ch;
void read(int &n)
{
n=0;
ch=G();
while((ch<'0' || ch>'9') && ch!='-')ch=G();
ll w=1;
if(ch=='-')w=-1,ch=G();
while('0'<=ch && ch<='9')n=(n<<3)+(n<<1)+ch-'0',ch=G();
n*=w;
}
int max(int a,int b){return a>b?a:b;}
int min(int a,int b){return a<b?a:b;}
ll abs(ll x){return x<0?-x:x;}
ll sqr(ll x){return x*x;}
void write(ll x){if(x>9) write(x/10);P(x%10+'0');}
int n,a,b,x,y,ans,tot;
int f[N],t[N],lst[N],nxt[N<<1],to[N<<1],g[N],up[N],pr[N],su[N];
int l,r,mid,p[N],m;
bool cmp(int a,int b)
{
return a>b;
}
void ins(int x,int y)
{
nxt[++tot]=lst[x];
to[tot]=y;
lst[x]=tot;
}
void work(int x,int fa)
{
int m=0;
for(int i=lst[x];i;i=nxt[i])
if(to[i]!=fa)t[++m]=f[to[i]];
sort(t+1,t+1+m,cmp);f[x]=0;
for(int i=1;i<=m;i++)
f[x]=max(f[x],i+t[i]);
}
void works(int x,int fa,int ttt)
{
int m=0;
for(int i=lst[x];i;i=nxt[i])
if(to[i]!=fa && to[i]!=ttt)t[++m]=f[to[i]];
sort(t+1,t+1+m,cmp);f[x]=0;
for(int i=1;i<=m;i++)
f[x]=max(f[x],i+t[i]);
}
void dfs(int x,int fa)
{
g[x]=fa;
for(int i=lst[x];i;i=nxt[i])
if(to[i]!=fa)dfs(to[i],x);
work(x,fa);
}
void root(int x)
{
vector<pair<int,int> >q;
for(int i=lst[x];i;i=nxt[i])
if(to[i]!=g[x])q.push_back(make_pair(f[to[i]],to[i]));
if(g[x])q.push_back(make_pair(up[x],x));
sort(q.begin(),q.end(),greater< pair<int,int> >());
pr[0]=q[0].fi+1;
for (int i=1;i<q.size();i++)
pr[i]=max(pr[i-1],q[i].fi+1+i);
su[q.size()]=0;
for (int i=q.size()-1;i>=0;i--)
su[i]=max(su[i+1],q[i].fi+1+i);
ans=min(ans,su[0]);
for (int i=0;i<q.size();i++)
if (q[i].se!=x)up[q[i].se]=max(i?pr[i-1]:0,su[i+1]-1);
for(int i=lst[x];i;i=nxt[i])
if(to[i]!=g[x])root(to[i]);
}
void pd(int mid)
{
works(p[mid+1],p[mid+2],p[mid]);
for(int i=mid+2;i<=m;i++)
work(p[i],p[i+1]);
works(p[mid],p[mid-1],p[mid+1]);
for(int i=mid-1;i;i--)
work(p[i],p[i-1]);
ans=min(ans,max(f[a],f[b]));
}
int main()
{
read(n);read(a);read(b);ans=2147483647;
for(int i=1;i<n;i++)
read(x),read(y),ins(x,y),ins(y,x);
dfs(a,0);
root(a);
write(ans);
P('\n');
for(x=b,m=0;x!=0;x=g[x])p[++m]=x;
l=1;r=m-1;
while(l<r)
{
mid=(l+r)>>1;
pd(mid);
if(f[a]<f[b])r=mid;else l=mid+1;
}
mid=(l+r)>>1;
pd(mid);
write(ans);
P('\n');
return 0;
}