题目大意
有一棵有 n 个节点的根节点为 1 的树,他只能走一条不经过重复节点的路径。
给出 q 个形如“x y”的询问,表示他不能走到 x 和 y 的子树中。现在他想知道对于每组询问,他能走的最长路径是多少,如果没有,输出0。
n, q<=10^5
【显然】
可走的地方还是一棵树,而我们要求的路径就是这棵树的直径。
【50%】n,q<=2000
每个询问暴力找直径。
【100%】n,q<=10^5
我们需要用线段树维护树的直径。
去掉两棵子树,相当于在dfs序上去掉两个区间(注意这两个区间可能有交或存在包含关系),我们把剩下的区间合并就行了。
//线段树维护树的直径:http://blog.csdn.net/rzo_kqp_orz/article/details/52280811
代码
#include<cmath>
#include<cstdio>
#include<algorithm>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
using namespace std;
typedef long long LL;
const int maxn=(1e5)+5, MX=18;
struct TR{
int x,y,len;
TR(int X=0,int Y=0,int LEN=0) {x=X, y=Y, len=LEN;}
};
int n;
int tot,go[2*maxn],next[2*maxn],f1[maxn];
void ins(int x,int y)
{
go[++tot]=y;
next[tot]=f1[x];
f1[x]=tot;
}
int fa[2*maxn][MX+5],deep[maxn],ap[2*maxn],fir[2*maxn],Log[2*maxn],er[MX+5];
void rmq_pre()
{
fo(i,1,ap[0]) fa[i][0]=ap[i], Log[i]=log(i)/log(2);
fo(i,0,MX) er[i]=1<<i;
fo(j,1,MX)
fo(i,1,ap[0])
{
fa[i][j]=fa[i][j-1];
if (i+er[j-1]<=ap[0] && deep[fa[i+er[j-1]][j-1]]<deep[fa[i][j]])
fa[i][j]=fa[i+er[j-1]][j-1];
}
}
int lca(int x,int y)
{
x=fir[x], y=fir[y];
if (x>y) swap(x,y);
int t=Log[y-x+1];
return (deep[fa[x][t]]<deep[fa[y-er[t]+1][t]]) ?fa[x][t] :fa[y-er[t]+1][t] ;
}
int st[maxn],en[maxn],sum,Tbh[maxn];
void dfs_pre(int k,int last)
{
deep[k]=deep[last]+1;
ap[++ap[0]]=k, fir[k]=ap[0];
Tbh[++sum]=k, st[k]=sum;
for(int p=f1[k]; p; p=next[p]) if (go[p]!=last)
{
dfs_pre(go[p],k);
ap[++ap[0]]=k;
}
en[k]=sum;
}
TR tr[4*maxn];
int DIS(int x,int y) {return deep[x]+deep[y]-deep[lca(x,y)]*2;}
TR merge(TR a,TR b)
{
if (a.len==-1) return b;
TR re= (a.len>b.len) ?a :b;
if (DIS(a.x,b.x)>re.len) re=TR(a.x,b.x,DIS(a.x,b.x));
if (DIS(a.x,b.y)>re.len) re=TR(a.x,b.y,DIS(a.x,b.y));
if (DIS(a.y,b.x)>re.len) re=TR(a.y,b.x,DIS(a.y,b.x));
if (DIS(a.y,b.y)>re.len) re=TR(a.y,b.y,DIS(a.y,b.y));
return re;
}
void tr_js(int k,int l,int r)
{
if (l==r)
{
tr[k].x=tr[k].y=Tbh[l];
tr[k].len=0;
return;
}
int t=k<<1, t1=(l+r)>>1;
tr_js(t,l,t1), tr_js(t+1,t1+1,r);
tr[k]=merge(tr[t],tr[t+1]);
}
TR tr_cx(int k,int l,int r,int x,int y)
{
if (l==x && r==y) return tr[k];
int t=k<<1, t1=(l+r)>>1;
if (y<=t1) return tr_cx(t,l,t1,x,y);
else if (x>t1) return tr_cx(t+1,t1+1,r,x,y);
else return merge(tr_cx(t,l,t1,x,t1),tr_cx(t+1,t1+1,r,t1+1,y));
}
int q;
int main()
{
freopen("snow.in","r",stdin);
freopen("snow.out","w",stdout);
scanf("%d %d",&n,&q);
fo(i,1,n-1)
{
int x,y;;
scanf("%d %d",&x,&y);
ins(x,y), ins(y,x);
}
dfs_pre(1,0);
rmq_pre();
tr_js(1,1,n);
while (q--)
{
int x,y;
scanf("%d %d",&x,&y);
if (st[x]>st[y]) swap(x,y);
TR ans=TR(0,0,-1);
if (1<st[x]) ans=merge(ans,tr_cx(1,1,n,1,st[x]-1));
if (en[x]+1<st[y]) ans=merge(ans,tr_cx(1,1,n,en[x]+1,st[y]-1));
int En=max(en[x],en[y]);
if (En<n) ans=merge(ans,tr_cx(1,1,n,En+1,n));
printf("%d\n",(ans.len==-1) ?0 :ans.len );
}
}