Description
给出一棵节点数为n的树,q次查询,每次查询两点间距离
Input
第一行为两整数n和m分别表示点数和边数,之后m行每行三个整数a,b,c表示a和b之间有一条权值为c的边,之后一个字符表示这条边的方向,之后为一整数q表示查询次数,最后q行每行两个整数a和b表示查询点a到点b的距离(1<=q<=10000)
Output
对于每次查询,输出查询结果
Sample Input
7 6
1 6 13 E
6 3 9 E
3 5 7 S
4 1 3 N
2 4 20 W
4 7 2 S
3
1 6
1 4
2 6
Sample Output
13
3
36
Solution
树上两点距离dis(i,j)=dis(root,i)+dis(root,j)-2*dis(root,lca(i,j)),所以问题变成两部分,一个是求两点lca,另一个是求树上每点到根节点的距离,其中第一部分由于查询次数比较大所以可以用离线tarjan求lca,而第二部分可以在求lca的过程中一并处理
Code
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<cmath>
using namespace std;
#define maxn 111111
#define maxq 111111
struct Edge
{
int to,next,c;
}edge[maxn*2];
struct Query
{
int to,next,id;
}query[maxq*2];
int f[maxn];
bool vis[maxn];
int ancestor[maxn];
int head1[maxn],tot1;
int head2[maxq],tot2;
int ans[maxq];
int dis[maxn];
void init()
{
tot1=tot2=0;
memset(head1,-1,sizeof(head1));
memset(head2,-1,sizeof(head2));
memset(vis,0,sizeof(vis));
memset(f,-1,sizeof(f));
memset(ancestor,0,sizeof(ancestor));
}
int find(int x)
{
if(f[x]==-1)return x;
return f[x]=find(f[x]);
}
void unite(int x,int y)
{
x=find(x),y=find(y);
if(x!=y)f[x]=y;
}
void add_edge(int u,int v,int c)
{
edge[tot1].to=v;
edge[tot1].c=c;
edge[tot1].next=head1[u];
head1[u]=tot1++;
}
void add_query(int u,int v,int id)
{
query[tot2].to=v;
query[tot2].next=head2[u];
query[tot2].id=id;
head2[u]=tot2++;
query[tot2].to=u;
query[tot2].next=head2[v];
query[tot2].id=id;
head2[v]=tot2++;
}
void tarjan(int u)
{
ancestor[u]=u;
vis[u]=1;
for(int i=head1[u];~i;i=edge[i].next)
{
int v=edge[i].to;
if(vis[v])continue;
dis[v]=dis[u]+edge[i].c;
tarjan(v);
unite(u,v);
ancestor[find(u)]=u;
}
for(int i=head2[u];~i;i=query[i].next)
{
int v=query[i].to;
if(vis[v])ans[query[i].id]=dis[u]+dis[v]-2*dis[ancestor[find(v)]];
}
}
int main()
{
int n,m,q;
while(~scanf("%d%d",&n,&m))
{
int u,v,c;char s[3];
init();
while(m--)
{
scanf("%d%d%d%s",&u,&v,&c,s);
add_edge(u,v,c),add_edge(v,u,c);
}
scanf("%d",&q);
for(int i=1;i<=q;i++)
{
scanf("%d%d",&u,&v);
add_query(u,v,i);
}
dis[1]=0;
tarjan(1);
for(int i=1;i<=q;i++)
printf("%d\n",ans[i]);
}
return 0;
}