线段树维护区间最长路径
对于一次合并
(a,b)+(c,d)
(
a
,
b
)
+
(
c
,
d
)
,若
(a,b)
(
a
,
b
)
中最长线段为
(x1,x2)
(
x
1
,
x
2
)
,
(c,d)
(
c
,
d
)
中最长线段为
(y1,y2)
(
y
1
,
y
2
)
,那么合并起来最长线段就是
(x1,x2),(x1,y1),(x1,y2),(x2,y1),(x2,y2),(y1,y2)
(
x
1
,
x
2
)
,
(
x
1
,
y
1
)
,
(
x
1
,
y
2
)
,
(
x
2
,
y
1
)
,
(
x
2
,
y
2
)
,
(
y
1
,
y
2
)
其中一个(好像可以化简成四个…并不是很理解【捂脸】)
求距离的时候 x x 到 的距离就是 dis(x)+dis(y)−2∗dis(lca(x,y)) d i s ( x ) + d i s ( y ) − 2 ∗ d i s ( l c a ( x , y ) )
至于 lca l c a ,之间倍增的话会 t ,因此使用 st s t 表优化
代码如下~
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define N 100010
int n,m,num=0,h[N],d[N],tot=0,dsum[N],log[N<<1],id[N],st[N<<2][21];
struct node{int x,y,z,next;}mp[N<<1];
inline int lca(int x,int y){
if(x>y) swap(x,y);
int len=log[y-x+1];
if(d[st[x][len]]<d[st[y-(1<<len)+1][len]]) return st[x][len];
else return st[y-(1<<len)+1][len];
}
inline int dist(int x,int y){
return dsum[x]+dsum[y]-2*dsum[lca(id[x],id[y])];
}
struct node1{
int l,r,s;
}t[N<<2];
bool operator<(node1 z1,node1 z2){return z1.s<z2.s;}
node1 operator+(node1 z1,node1 z2){
int s1=-1,s2=-1;
int s3=dist(z1.l,z2.l);
int s4=dist(z1.l,z2.r);
int s5=dist(z1.r,z2.l);
int s6=dist(z1.r,z2.r);
if(s1>=s2 && s1>=s3 && s1>=s4 && s1>=s5 && s1>=s6) return (node1){z1.l,z1.r,s1};
if(s2>=s1 && s2>=s3 && s2>=s4 && s2>=s5 && s2>=s6) return (node1){z2.l,z2.r,s2};
if(s3>=s2 && s3>=s1 && s3>=s4 && s3>=s5 && s3>=s6) return (node1){z1.l,z2.l,s3};
if(s4>=s2 && s4>=s3 && s4>=s1 && s4>=s5 && s4>=s6) return (node1){z1.l,z2.r,s4};
if(s5>=s2 && s5>=s3 && s5>=s4 && s5>=s1 && s5>=s6) return (node1){z1.r,z2.l,s5};
if(s6>=s2 && s6>=s3 && s6>=s4 && s6>=s1 && s6>=s5) return (node1){z1.r,z2.r,s6};
}
inline char gc(){
static char buf[1<<16],*S,*T;
if(T==S){T=(S=buf)+fread(buf,1,1<<16,stdin);if(T==S) return EOF;}
return *S++;
}
inline int read(){
int x=0,f=1;char ch=gc();
while(ch<'0'||ch>'9'){if(ch=='-') f=-1;ch=gc();}
while(ch>='0'&&ch<='9') x=x*10+ch-'0',ch=gc();
return x*f;
}
inline void insert(int x,int y,int z){
mp[++num].x=x;mp[num].y=y;mp[num].z=z;mp[num].next=h[x];h[x]=num;
mp[++num].x=y;mp[num].y=x;mp[num].z=z;mp[num].next=h[y];h[y]=num;
}
void dfs(int x){
id[x]=++tot;st[tot][0]=x;
for(int i=h[x];i;i=mp[i].next){
int y=mp[i].y;if(d[y]) continue;
d[y]=d[x]+1;dsum[y]=dsum[x]+mp[i].z;
dfs(y);st[++tot][0]=x;
}
}
void build(int v,int l,int r){
if(l==r){
t[v].l=l;t[v].r=r;t[v].s=0;
return;
}int mid=l+r>>1;
build(v<<1,l,mid);build(v<<1|1,mid+1,r);
t[v]=max(t[v<<1],max(t[v<<1|1],t[v<<1]+t[v<<1|1]));
}
node1 query(int v,int l,int r,int x,int y){
if(x<=l && r<=y) return t[v];
int mid=l+r>>1;
if(y<=mid) return query(v<<1,l,mid,x,y);
else if(x>mid) return query(v<<1|1,mid+1,r,x,y);
else{
node1 x1=query(v<<1,l,mid,x,mid);
node1 x2=query(v<<1|1,mid+1,r,mid+1,y);
return max(x1+x2,max(x1,x2));
}
}
void getrmq(){
for(int i=1;i<=log[tot];++i){
for(int j=1;j<=tot+1-(1<<i);++j)
if(d[st[j][i-1]]<d[st[j+(1<<i-1)][i-1]]) st[j][i]=st[j][i-1];
else st[j][i]=st[j+(1<<i-1)][i-1];
}
}
int main(){
n=read();
log[0]=-1;for(int i=1;i<2*N;++i) log[i]=log[i>>1]+1;
for(int i=1;i<n;i++){
int x=read(),y=read(),z=read();
insert(x,y,z);
}
d[1]=1;dfs(1);
getrmq();
build(1,1,n);
m=read();
for(int i=1;i<=m;i++){
int x1=read(),x2=read(),y1=read(),y2=read();
node1 ans=query(1,1,n,x1,x2)+query(1,1,n,y1,y2);
printf("%d\n",ans.s);
}
return 0;
}