题目描述
给出 n 个点的一棵树,多次询问两点之间的最短距离。
注意:
边是无向的。
所有节点的编号是 1,2,…,n。
输入格式
第一行为两个整数 n 和 m。n 表示点数,m 表示询问次数;
下来 n−1 行,每行三个整数 x,y,k,表示点 x 和点 y 之间存在一条边长度为 k;
再接下来 m 行,每行两个整数 x,y,表示询问点 x 到点 y 的最短距离。
树中结点编号从 1 到 n。
输出格式
共 m 行,对于每次询问,输出一行询问结果。
数据范围
2≤n≤104,
1≤m≤2×104,
0<k≤100,
1≤x,y≤n
输入样例1
2 2
1 2 100
1 2
2 1
输出样例1
100
100
输入样例2
3 2
1 2 10
3 1 15
1 2
3 2
输出样例2
10
25
题目分析
这是一道求最近公共祖先的模板题。
tarjan算法求最近公共祖先: O(n+m) 离线算法
在深度优先遍历时,将所有点分成三大类:
(2号点集)表示已经访问过且已经回溯完的点
(1号点集)表示正在进行访问的节点
(0号点集)表示已经访问结束的点
我们可以在遍历1号点集中的点的时候,将2号点集中的点用并查集合并到其1号点集中的父节点上去。
如果有一对查询a和b,a点在1号点集中,b点在2号点集中,那么它们的最近公共祖先即为b节点所在连通块的根节点。
然后我们再回过头来看看这道题:我们可以用dist[i]表示i节点到根节点的距离。
a节点到b节点的距离即为dist[a]+dist[b]-dist[p]*2 //p为a和b的最近公共祖先
代码如下
#include <iostream>
#include <cstdio>
#include <cmath>
#include <string>
#include <cstring>
#include <map>
#include <queue>
#include <stack>
#include <vector>
#include <set>
#include <algorithm>
#define LL long long
#define ULL unsigned long long
#define PII pair<int,int>
#define x first
#define y second
using namespace std;
const int N=1e4+5,INF=0x3f3f3f3f;
int h[N],e[N*2],w[N*2],ne[N*2],idx;
int dist[N],p[N];
int st[N],ans[N*2];
vector<PII> query[N];
void add(int a,int b,int c) //加边函数
{
e[idx]=b;
w[idx]=c;
ne[idx]=h[a];
h[a]=idx++;
}
int find(int x) //并查集模板函数
{
if(p[x]!=x) p[x]=find(p[x]);
return p[x];
}
void dfs(int u,int fa) //dfs预处理出dist[]数组
{
for(int i=h[u];~i;i=ne[i])
{
int v=e[i];
if(v==fa) continue;
dist[v]=dist[u]+w[i];
dfs(v,u);
}
}
void tarjan(int u)
{
st[u]=1; //遍历开始,将当前节点归入1号点集
for(int i=h[u];~i;i=ne[i]) //遍历其所有子节点
{
int v=e[i];
if(!st[v]) //如果v是u的子节点(即v属于0号点集)
{
tarjan(v); //递归遍历
p[v]=u; //其v归入以u为根的连通块中
}
}
for(PII t:query[u]) //遍历有关u节点的所有的查询
if(st[t.x]==2) //如果另一个点已被遍历过
{
int anc=find(t.x); //则另一个点所在的连通块的根节点即为其最近公共祖先
ans[t.y]=dist[u]+dist[t.x]-dist[anc]*2; //计算答案
}
st[u]=2; //遍历结束,将u归入2号点集
}
int main()
{
for(int i=1;i<N;i++) p[i]=i;
memset(h,-1,sizeof h);
int n,m;
scanf("%d %d",&n,&m);
for(int i=1;i<n;i++) //建图
{
int u,v,w;
scanf("%d %d %d",&u,&v,&w);
add(u,v,w); add(v,u,w);
}
for(int i=0;i<m;i++) //记录所有的查询
{
int a,b;
scanf("%d %d",&a,&b);
if(a!=b)
{
query[a].push_back({b,i});
query[b].push_back({a,i});
}
}
dfs(1,-1);
tarjan(1);
for(int i=0;i<m;i++) //输出答案
printf("%d\n",ans[i]);
return 0;
}