Description
n个点被n-1条边连接成了一颗树,给出a~b和c~d两个区间,表示点的标号请你求出两个区间内各选一点之间的最大距离,即你需要求出max{dis(i,j) |a<=i<=b,c<=j<=d}
(PS 建议使用读入优化)
Input
第一行一个数字 n n<=100000。
第二行到第n行每行三个数字描述路的情况, x,y,z (1<=x,y<=n,1<=z<=10000)表示x和y之间有一条长度为z的路。
第n+1行一个数字m,表示询问次数 m<=100000。
接下来m行,每行四个数a,b,c,d。
Output
共m行,表示每次询问的最远距离
Input示例
5
1 2 1
2 3 2
1 4 3
4 5 4
1
2 3 4 5
Output示例
10
分析
考虑用线段树维护区间的最远点对。
合并两个区间:两个区间分别有两个点对,一共四个点,两两求一遍距离取最大的点对即可(和合并两个树时更新树的直径一样)。
注意常数优化。
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
using namespace std;
const int maxn=100005,maxm=200005,Log=17,maxt=262205;
int n,q,tot,rmq[maxm][Log+1],dep[maxn],size[maxn],h[maxn],e[maxm],next[maxm],la[maxn],fa[maxn],len[maxm],dis[maxn];
char c;
struct data
{
int p,q,dis;
}t[maxt];
int read()
{
for (c=getchar();c<'0' || c>'9';c=getchar());
int x=c-48;
for (c=getchar();c>='0' && c<='9';c=getchar()) x=x*10+c-48;
return x;
}
void add(int x,int y,int l)
{
e[++tot]=y; next[tot]=h[x]; h[x]=tot; len[tot]=l;
}
void dfs(int x)
{
dep[x]=dep[fa[x]]+1;
rmq[la[x]=tot++][0]=x;
for (int i=h[x];i;i=next[i]) if (e[i]!=fa[x])
{
fa[e[i]]=x; dis[e[i]]=dis[x]+len[i];
dfs(e[i]);
size[x]+=size[e[i]]+1;
rmq[la[x]=tot++][0]=x;
}
}
int getlca(int x,int y)
{
if (la[x]>la[y]) x^=y^=x^=y;
x=la[x]; y=la[y];
int k=log(y-x+1)/log(2);
return (dep[rmq[x][k]]<=dep[rmq[y-(1<<k)+1][k]])?rmq[x][k]:rmq[y-(1<<k)+1][k];
}
int getdis(int x,int y)
{
return dis[x]+dis[y]-2*dis[getlca(x,y)];
}
data cmp(data a,data b)
{
data t=a;
if (t.dis<b.dis) t=b;
int d1=getdis(a.p,b.p),d2=getdis(a.p,b.q),d3=getdis(a.q,b.p),d4=getdis(a.q,b.q);
if (t.dis<d1)
{
t.p=a.p; t.q=b.p; t.dis=d1;
}
if (t.dis<d2)
{
t.p=a.p; t.q=b.q; t.dis=d2;
}
if (t.dis<d3)
{
t.p=a.q; t.q=b.p; t.dis=d3;
}
if (t.dis<d4)
{
t.p=a.q; t.q=b.q; t.dis=d4;
}
return t;
}
void init(int l,int r,int x)
{
if (l==r)
{
t[x].p=t[x].q=l;
return;
}
int mid=(l+r)/2;
init(l,mid,x*2); init(mid+1,r,x*2+1);
t[x]=cmp(t[x*2],t[x*2+1]);
}
data getmax(int l,int r,int a,int b,int x)
{
if (l==a && r==b) return t[x];
int mid=(l+r)/2;
if (b<=mid) return getmax(l,mid,a,b,x*2);
if (a>mid) return getmax(mid+1,r,a,b,x*2+1);
return cmp(getmax(l,mid,a,mid,x*2),getmax(mid+1,r,mid+1,b,x*2+1));
}
int main()
{
n=read();
for (int i=1;i<n;i++)
{
int x=read(),y=read(),l=read();
add(x,y,l); add(y,x,l);
}
tot=0;
dfs(1);
for (int j=1;j<=Log;j++)
for (int i=0;i<=tot-(1<<j);i++)
rmq[i][j]=(dep[rmq[i][j-1]]<=dep[rmq[i+(1<<(j-1))][j-1]])?rmq[i][j-1]:rmq[i+(1<<(j-1))][j-1];
init(1,n,1);
q=read();
while (q--)
{
int x1=read(),y1=read(),x2=read(),y2=read();
data a=getmax(1,n,x1,y1,1),b=getmax(1,n,x2,y2,1),t;
t.p=a.p; t.q=b.p; t.dis=getdis(t.p,t.q);
int d2=getdis(a.p,b.q),d3=getdis(a.q,b.p),d4=getdis(a.q,b.q);
if (t.dis<d2)
{
t.p=a.p; t.q=b.q; t.dis=d2;
}
if (t.dis<d3)
{
t.p=a.q; t.q=b.p; t.dis=d3;
}
if (t.dis<d4)
{
t.p=a.q; t.q=b.q; t.dis=d4;
}
printf("%d\n",t.dis);
}
return 0;
}