Description
给出一棵有根树,1为根。
给出q次询问,每次询问x,y表示除x,y为根的子树外,剩下的树的直径的长度。
n,q<=10^5
Solution
既然和子树有关,那么我们就维护树的dfs序。
然后每个区间维护直径的长度。用线段树,同51nod1766树上的最远点对.
那么不能用x,y为根的子树就是不能用某两个区间。这样就把原序列分成了最多三个区间,合并起来就好了。
Code
#include<cmath>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
#define rep(i,a) for(int i=last[a];i;i=next[i])
#define N 100005
using namespace std;
struct note{int a,b;}tr[N*5],p[2],tmp;
bool cmp(note x,note y) {return x.a<y.a;}
int n,m,x,y,l,tot,top,c[N],d[N],dfn[N],q[N*2],fir[N],size[N],f[N*2][18];
int t[N*2],next[N*2],last[N],mi[18];
void add(int x,int y) {
t[++l]=y;next[l]=last[x];last[x]=l;
}
void dfs(int x,int y) {
dfn[x]=++tot;c[tot]=x;size[x]=1;d[x]=d[y]+1;q[++top]=x;fir[x]=top;
rep(i,x) if (t[i]!=y) dfs(t[i],x),size[x]+=size[t[i]],q[++top]=x;
}
int lca(int x,int y) {
x=fir[x];y=fir[y];
if (x>y) swap(x,y);
int z=log2(y-x+1);
if (d[q[f[x][z]]]<d[q[f[y-mi[z]+1][z]]]) return q[f[x][z]];
else return q[f[y-mi[z]+1][z]];
}
int len(int x,int y) {
int z=lca(x,y);
return d[x]+d[y]-2*d[z];
}
note merge(note y,note z) {
note x;int mx=0,l;
if (!(y.a+y.b)) return z;
l=len(y.a,y.b);if (l>mx) mx=l,x.a=y.a,x.b=y.b;
l=len(z.a,z.b);if (l>mx) mx=l,x.a=z.a,x.b=z.b;
l=len(y.a,z.a);if (l>mx) mx=l,x.a=y.a,x.b=z.a;
l=len(y.a,z.b);if (l>mx) mx=l,x.a=y.a,x.b=z.b;
l=len(y.b,z.a);if (l>mx) mx=l,x.a=y.b,x.b=z.a;
l=len(y.b,z.b);if (l>mx) mx=l,x.a=y.b,x.b=z.b;
return x;
}
void build(int v,int l,int r) {
if (l==r) {tr[v].a=tr[v].b=c[l];return;}
int m=(l+r)/2;
build(v*2,l,m);build(v*2+1,m+1,r);
tr[v]=merge(tr[v*2],tr[v*2+1]);
}
note find(int v,int l,int r,int x,int y) {
if (l==x&&r==y) return tr[v];
int m=(l+r)/2;
if (y<=m) return find(v*2,l,m,x,y);
else if (x>m) return find(v*2+1,m+1,r,x,y);
else return merge(find(v*2,l,m,x,m),find(v*2+1,m+1,r,m+1,y));
}
int main() {
freopen("snow.in","r",stdin);
freopen("snow.out","w",stdout);
scanf("%d%d",&n,&m);
fo(i,1,n-1) scanf("%d%d",&x,&y),
add(x,y),add(y,x);dfs(1,0);mi[0]=1;
fo(i,1,top) f[i][0]=i;
fo(i,1,log2(top)) mi[i]=mi[i-1]*2;
fo(j,1,log2(top))
fo(i,1,top-mi[j]+1)
if (d[q[f[i][j-1]]]<d[q[f[i+mi[j-1]][j-1]]]) f[i][j]=f[i][j-1];
else f[i][j]=f[i+mi[j-1]][j-1];
build(1,1,n);
for(;m;m--) {
scanf("%d%d",&x,&y);
if (x==1||y==1) {printf("0\n");continue;}
p[0].a=dfn[x];p[0].b=dfn[x]+size[x]-1;
p[1].a=dfn[y];p[1].b=dfn[y]+size[y]-1;
sort(p,p+2,cmp);tmp.a=tmp.b=0;
if (p[0].a>1) tmp=merge(tmp,find(1,1,n,1,p[0].a-1));
if (p[0].b+1<=p[1].a-1) tmp=merge(tmp,find(1,1,n,p[0].b+1,p[1].a-1));
int ri=max(p[0].b,p[1].b);
if (ri<n) tmp=merge(tmp,find(1,1,n,ri+1,n));
printf("%d\n",len(tmp.a,tmp.b));
}
}