题目地址:
https://www.acwing.com/problem/content/1173/
给出 n n n个点的一棵树,多次询问两点之间的最短距离。
注意:边是无向的。所有节点的编号是 1 , 2 , … , n 1,2,…,n 1,2,…,n。
输入格式:
第一行为两个整数
n
n
n和
m
m
m。
n
n
n表示点数,
m
m
m表示询问次数;下来
n
−
1
n−1
n−1行,每行三个整数
x
,
y
,
k
x,y,k
x,y,k,表示点
x
x
x和点
y
y
y之间存在一条边长度为
k
k
k;再接下来
m
m
m行,每行两个整数
x
,
y
x,y
x,y,表示询问点
x
x
x到点
y
y
y的最短距离。树中结点编号从
1
1
1到
n
n
n。
输出格式:
共
m
m
m行,对于每次询问,输出一行询问结果。
数据范围:
2
≤
n
≤
1
0
4
2≤n≤10^4
2≤n≤104
1
≤
m
≤
2
×
1
0
4
1≤m≤2×10^4
1≤m≤2×104
0
<
k
≤
100
0<k≤100
0<k≤100
1
≤
x
,
y
≤
n
1≤x,y≤n
1≤x,y≤n
这一题可以离线地做,即已知所有的询问,然后设计一种方法一次性地回答所有询问。在规定任意一个点为树根之后,询问两个点的最短距离,可以先DFS一下预处理出每个点到树根的距离的数组 d d d,那么对于 x , y x,y x,y之间的最短距离,设其最近公共祖先是 u u u,那么最短距离就是 d [ x ] + d [ y ] − 2 d [ u ] d[x]+d[y]-2d[u] d[x]+d[y]−2d[u]。那么问题转化为离线求两个点的最近公共祖先,可以用Tarjan算法。
其思想是基于DFS和并查集。考虑一棵DFS树,在遍历到当前节点时,我们把所有的节点分为三个部分,还未访问的点标记为 0 0 0,还有孩子没访问过的点记为 1 1 1(即从树根走到当前节点的链),回溯过的点记为 2 2 2。每次遍历完一个分支之后,将这个分支的子树根在并查集里接到当前节点上。当一条链走到底的时候,从链上开始回溯,此时每次回溯之前都遍历所有标记为 2 2 2的点 u u u,这些点与当前点 v v v的最近公共祖先,其实就是 u u u在并查集里的祖宗节点(因为每次DFS完一个分支之后,那个分支里所有的点在并查集里的祖宗都是那个分支的子树根)。从链上回溯之前,都标记一下当前点的状态是 2 2 2。所以,在一条链上开始回溯的时候,就同时计算出了含当前点的所有询问的结果。代码如下:
#include <iostream>
#include <cstring>
#include <unordered_map>
#include <vector>
using namespace std;
typedef pair<int, int> PII;
const int N = 20010, M = N * 2;
int n, m;
int h[N], e[M], w[M], ne[M], idx;
int dist[N];
int p[N];
int res[N];
int st[N];
// query[x]是含x这个节点的所有询问的信息,
// vector里面的pair的first是询问的另一个点,second是询问的下标
unordered_map<int, vector<PII> > query;
void add(int a, int b, int c) {
e[idx] = b, ne[idx] = h[a], w[idx] = c, h[a] = idx++;
}
// 以u为树根做DFS,求一下每个点到u的距离
void dfs(int u, int father) {
for (int i = h[u]; ~i; i = ne[i]) {
int j = e[i];
// 略过父亲
if (j == father) continue;
dist[j] = dist[u] + w[i];
dfs(j, u);
}
}
int find(int x) {
if (p[x] != x) p[x] = find(p[x]);
return p[x];
}
// 从u出发开始DFS,
void tarjan(int u) {
st[u] = 1;
for (int i = h[u]; ~i; i = ne[i]) {
int j = e[i];
if (!st[j]) {
tarjan(j);
p[j] = u;
}
}
// 走到一条链的最后一个节点了(即走到叶子了),开始回答询问;
// 找一下所有被标记为2的点,开始回答含这些点的所有询问
for (auto item : query[u]) {
int y = item.first, id = item.second;
// 略过已经回答过的询问
if (res[id]) continue;
if (st[y] == 2) {
int anc = find(y);
res[id] = dist[u] + dist[y] - dist[anc] * 2;
}
}
// 回溯之前标记一下当前点状态为2
st[u] = 2;
}
int main() {
scanf("%d%d", &n, &m);
memset(h, -1, sizeof h);
for (int i = 0; i < n - 1; i++) {
int a, b, c;
scanf("%d%d%d", &a, &b, &c);
add(a, b, c), add(b, a, c);
}
for (int i = 0; i < m; i++) {
int a, b;
scanf("%d%d", &a, &b);
if (a != b) {
if (!query.count(a)) query[a] = vector<PII>();
if (!query.count(b)) query[b] = vector<PII>();
query[a].push_back({b, i});
query[b].push_back({a, i});
}
}
// 初始化并查集
for (int i = 1; i <= n; i++) p[i] = i;
// 以1为根,预处理一下每个点到1的距离
dfs(1, -1);
tarjan(1);
for (int i = 0; i < m; i++) printf("%d\n", res[i]);
return 0;
}
时间复杂度 O ( n + m log ∗ n ) O(n+m\log ^*n) O(n+mlog∗n),空间 O ( n ) O(n) O(n)。