Description
给出一棵n个点的树,每次询问编号在[a,b]中的一个点和编号在[c,d]一个点的最远距离。
n<=10^5
Solution
我们知道,树上最远的距离是树的直径。
然后,直径可以由两个点集中的直径的总共四个端点两两配对得到。
于是我们就可以用线段树来维护这个东西。
注意求距离要用欧拉序列,不能用倍增,否则会爆炸性超时。
Code
#include<cmath>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
#define rep(i,a) for(int i=last[a];i;i=next[i])
#define N 100005
using namespace std;
struct note{int a,b;}tr[N*5];
bool cmp(note x,note y) {return x.a<y.a;}
int n,m,x,y,z,k,l,tot,d[N],dfn[N*2],fir[N],sum[N],f[N*2][18];
int t[N*2],next[N*2],v[N*2],last[N],mi[18];
int get() {
char ch;while (!isdigit(ch=getchar()));
int o=ch-48;while (isdigit(ch=getchar())) o=o*10+ch-48;
return o;
}
void add(int x,int y,int z) {
t[++l]=y;v[l]=z;next[l]=last[x];last[x]=l;
}
void dfs(int x,int y) {
d[x]=d[y]+1;dfn[++tot]=x;fir[x]=tot;
rep(i,x) if (t[i]!=y) sum[t[i]]=sum[x]+v[i],dfs(t[i],x),dfn[++tot]=x;;
}
int lca(int x,int y) {
x=fir[x];y=fir[y];
if (x>y) swap(x,y);
int z=log2(y-x+1);
if (d[dfn[f[x][z]]]<d[dfn[f[y-mi[z]+1][z]]]) return dfn[f[x][z]];
else return dfn[f[y-mi[z]+1][z]];
}
int len(int x,int y) {
int z=lca(x,y);
return sum[x]+sum[y]-2*sum[z];
}
note merge(note y,note z,int bz) {
note x;int mx=0,l;
if (!bz) {
l=len(y.a,y.b);if (l>mx) mx=l,x.a=y.a,x.b=y.b;
l=len(z.a,z.b);if (l>mx) mx=l,x.a=z.a,x.b=z.b;
}
l=len(y.a,z.a);if (l>mx) mx=l,x.a=y.a,x.b=z.a;
l=len(y.a,z.b);if (l>mx) mx=l,x.a=y.a,x.b=z.b;
l=len(y.b,z.a);if (l>mx) mx=l,x.a=y.b,x.b=z.a;
l=len(y.b,z.b);if (l>mx) mx=l,x.a=y.b,x.b=z.b;
return x;
}
void build(int v,int l,int r) {
if (l==r) {tr[v].a=tr[v].b=l;return;}
int m=(l+r)/2;
build(v*2,l,m);build(v*2+1,m+1,r);
tr[v]=merge(tr[v*2],tr[v*2+1],0);
}
note find(int v,int l,int r,int x,int y) {
if (l==x&&r==y) return tr[v];
int m=(l+r)/2;
if (y<=m) return find(v*2,l,m,x,y);
else if (x>m) return find(v*2+1,m+1,r,x,y);
else return merge(find(v*2,l,m,x,m),find(v*2+1,m+1,r,m+1,y),0);
}
int main() {
n=get();
fo(i,1,n-1) x=get(),y=get(),z=get(),
add(x,y,z),add(y,x,z);dfs(1,0);mi[0]=1;
fo(i,1,tot) f[i][0]=i,mi[i]=mi[i-1]*2;
fo(j,1,log2(tot))
fo(i,1,tot-mi[j]+1)
if (d[dfn[f[i][j-1]]]<d[dfn[f[i+mi[j-1]][j-1]]]) f[i][j]=f[i][j-1];
else f[i][j]=f[i+mi[j-1]][j-1];
build(1,1,n);
for(m=get();m;m--) {
x=get();y=get();z=get();k=get();
note tmp=merge(find(1,1,n,x,y),find(1,1,n,z,k),1);
printf("%d\n",len(tmp.a,tmp.b));
}
}