题目描述
给出 n 个点的一棵树,多次询问两点之间的最短距离。
注意:
边是无向的。
所有节点的编号是 1,2,…,n。
输入格式
第一行为两个整数 n 和 m。n 表示点数,m 表示询问次数;
下来 n−1 行,每行三个整数 x,y,k,表示点 x 和点 y 之间存在一条边长度为 k;
再接下来 m 行,每行两个整数 x,y,表示询问点 x 到点 y 的最短距离。
树中结点编号从 1 到 n。
输出格式
共 m 行,对于每次询问,输出一行询问结果。
数据范围
2≤n≤104,
1≤m≤2×104,
0<k≤100,
1≤x,y≤n
输入样例1:
2 2
1 2 100
1 2
2 1
输出样例1:
100
100
输入样例2:
3 2
1 2 10
3 1 15
1 2
3 2
输出样例2:
10
25
思路:
使用tarjan算法。
将所有的点通过st数组分成三类,分别为0,1,2;
- 未涉及到的点
- 正在访问的节点
- 已经访问过的节点(此时,该节点已通过并查集合并到最近公共祖先节点所在集合,最近公共祖先为代表元素)
首先先用bfs初始化每个节点到根节点的距离。
使用tarjan算法:
- 将正在搜索的节点即root标记为1
- 本质上为dfs,去dfs所有的邻接点。此步骤完成后,root的子树节点均已被标记为2。
- 去遍历与root节点所有相关的查询。
- 若另外一个节点已被标记为2(即已被归并到共同节点)那么最近公共祖先就为find
- 否则跳过
代码:
#include<iostream>
#include<vector>
#include<cstring>
using namespace std;
typedef pair<int, int>PII;
const int N = 1e6 + 10;
int fa[N];
int e[N], ne[N], w[N], h[N],idx = 0;
int res[N];
int dist[N];
int st[N];
vector<PII>q[N];
void add(int a, int b, int c)
{
w[idx] = c;
e[idx] = b;
ne[idx] = h[a];
h[a] = idx++;
}
int find(int x)
{
if (fa[x] != x) fa[x] = find(fa[x]);
return fa[x];
}
void bfs(int u,int Fa)
{
for (int i = h[u]; i != -1; i = ne[i])
{
int j = e[i];
if (j != Fa)
{
dist[j] = dist[u] + w[i];
bfs(j,u);
}
}
}
void tarjan(int root)
{
st[root] = 1;
for (int i = h[root]; i != -1; i = ne[i])
{
int j = e[i];
if (st[j] == 0)
{
tarjan(j);
fa[j] = root;
}
}
for (auto p : q[root])
{
int a = p.first;
int b = p.second;
if (st[a] == 2)
{
int anc = find(a);
res[b] = dist[root] + dist[a] - dist[anc] * 2;
}
}
st[root] = 2;
}
int main()
{
memset(st, 0, sizeof st);
memset(h, -1, sizeof h);
int n;
cin >> n;
int m;
cin >> m;
for (int i = 0; i < N; i++)
{
fa[i] = i;
}
for (int i = 0; i < n - 1; i++)
{
int a, b, c;
cin >> a >> b >> c;
add(a, b, c);
add(b, a, c);
}
for (int i = 0; i < m; i++)
{
int a, b;
cin >> a >> b;
if (a != b)
{
q[a].push_back({ b, i });
q[b].push_back({ a, i });
}
}
bfs(1,-1);
tarjan(1);
for (int i = 0; i < m; i++)
{
cout << res[i] << endl;
}
return 0;
}
附:倍增算法
#include<iostream>
#include<queue>
#include<cstring>
using namespace std;
const int N = 1e6 + 10;
int e[N], ne[N], w[N], h[N],idx = 0;
int depth[N];
int dist[N];
int fa[N][17];
void add(int a, int b,int c)
{
w[idx] = c;
e[idx] = b;
ne[idx] = h[a];
h[a] = idx++;
}
void bfs(int root)
{
dist[root] = 0;
depth[root] = 1;
depth[0] = 0;
queue<int>q;
q.push(root);
while (!q.empty())
{
int t = q.front();
q.pop();
for (int i = h[t]; i != -1; i = ne[i])
{
int j = e[i];
if (depth[j] > depth[t] + 1)
{
dist[j] = dist[t] + w[i];
depth[j] = depth[t] + 1;
q.push(j);
fa[j][0] = t;
for (int k = 1; k < 16; k++)
{
fa[j][k] = fa[fa[j][k - 1]][k - 1];
}
}
}
}
}
int lca(int a,int b)
{
if (depth[a] < depth[b])
{
swap(a, b);
}
for (int k = 16; k >= 0; k--)
{
if (depth[fa[a][k]] >= depth[b])
{
a = fa[a][k];
}
}
if (a == b)
{
return a;
}
for (int k = 16; k >= 0; k--)
{
if (fa[a][k] != fa[b][k])
{
a = fa[a][k];
b = fa[b][k];
}
}
return fa[a][0];
}
int main()
{
memset(h, -1, sizeof h);
memset(depth, 0x3f3f3f3f, sizeof depth);
memset(dist, 0x3f3f3f3f, sizeof dist);
int n, m;
cin >> n >> m;
for (int i = 0; i < n - 1; i++)
{
int a, b,c;
cin >> a >> b>>c;
add(a, b,c);
add(b, a,c);
}
bfs(1);
for (int i = 0; i < m; i++)
{
int a, b;
cin >> a >> b;
int anc=lca(a, b);
cout << dist[a] + dist[b] - dist[anc] * 2<<endl;
}
}