题解:
看到多个区间询问就应该要考虑用数据结构维护了。考虑用线段树,每个区间维护两个点,表示编号在这个区间内的点构成的树的直径的两个端点,然后就可以直接合并了。
合并的时候要求
O
(
1
)
O(1)
O(1)求出两点之间的距离,要用那种rmq的LCA,否则复杂度是
O
(
n
l
o
g
2
n
)
O(nlog^2n)
O(nlog2n)的。
这个东西还可以用来解决去掉某棵子树之后的树的直径等跟直径有关的问题。
代码:
#include<bits/stdc++.h>
using namespace std;
#define LL long long
#define pa pair<int,int>
const int Maxn=100010;
const int inf=2147483647;
int read()
{
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
return x*f;
}
int n,m,dis[Maxn],dep[Maxn],a[Maxn<<1],tot=0,Log[Maxn<<1],f[Maxn<<1][18],fir[Maxn];
struct Edge{int y,d,next;}e[Maxn<<1];
int last[Maxn],len=0;
void ins(int x,int y,int d)
{
int t=++len;
e[t].y=y;e[t].d=d;e[t].next=last[x];last[x]=t;
}
void dfs(int x,int fa)
{
dep[x]=dep[fa]+1;fir[x]=++tot;f[tot][0]=x;
for(int i=last[x];i;i=e[i].next)
{
int y=e[i].y;
if(y==fa)continue;
dis[y]=dis[x]+e[i].d;dfs(y,x);f[++tot][0]=x;
}
}
int LCA(int x,int y)
{
if(fir[x]>fir[y])swap(x,y);
int t=Log[fir[y]-fir[x]+1];
int p1=f[fir[x]][t],p2=f[fir[y]-(1<<t)+1][t];
if(dep[p1]<dep[p2])return p1;
return p2;
}
int dd(int x,int y){return(LL)dis[x]+dis[y]-(dis[LCA(x,y)]<<1);}
struct Seg{int l,r,lc,rc,p1,p2;}tr[Maxn<<1];
int trlen=0;
void merge(Seg &c,Seg a,Seg b)
{
c.p1=a.p1,c.p2=a.p2;
if(dd(b.p1,b.p2)>dd(c.p1,c.p2))c.p1=b.p1,c.p2=b.p2;
if(dd(a.p1,b.p1)>dd(c.p1,c.p2))c.p1=a.p1,c.p2=b.p1;
if(dd(a.p1,b.p2)>dd(c.p1,c.p2))c.p1=a.p1,c.p2=b.p2;
if(dd(a.p2,b.p1)>dd(c.p1,c.p2))c.p1=a.p2,c.p2=b.p1;
if(dd(a.p2,b.p2)>dd(c.p1,c.p2))c.p1=a.p2,c.p2=b.p2;
}
void build(int l,int r)
{
int t=++trlen;
tr[t].l=l;tr[t].r=r;
if(l==r){tr[t].p1=tr[t].p2=l;return;}
int mid=l+r>>1;
tr[t].lc=trlen+1,build(l,mid);
tr[t].rc=trlen+1,build(mid+1,r);
merge(tr[t],tr[tr[t].lc],tr[tr[t].rc]);
}
Seg query(int x,int l,int r)
{
if(tr[x].l==l&&tr[x].r==r)return tr[x];
int lc=tr[x].lc,rc=tr[x].rc,mid=tr[x].l+tr[x].r>>1;
if(r<=mid)return query(lc,l,r);
if(l>mid)return query(rc,l,r);
Seg re;
merge(re,query(lc,l,mid),query(rc,mid+1,r));
return re;
}
int main()
{
n=read();
Log[1]=0;for(int i=2;i<=(n<<1);i++)Log[i]=Log[i>>1]+1;
for(int i=1;i<n;i++)
{
int x=read(),y=read(),d=read();
ins(x,y,d),ins(y,x,d);
}
dep[0]=-1;dis[1]=0;dfs(1,0);
for(int j=1;(1<<j)<=tot;j++)
for(int i=1;i+(1<<j)-1<=tot;i++)
{
int p1=f[i][j-1],p2=f[i+(1<<(j-1))][j-1];
if(dep[p1]<dep[p2])f[i][j]=p1;
else f[i][j]=p2;
}
build(1,n);
m=read();
while(m--)
{
int A=read(),B=read(),C=read(),D=read();
Seg t1=query(1,A,B),t2=query(1,C,D),t;
printf("%d\n",max(max(dd(t1.p1,t2.p1),dd(t1.p1,t2.p2)),max(dd(t1.p2,t2.p1),dd(t1.p2,t2.p2))));
}
}