给定一棵 n n n个结点的树,求解这个树中所有直径的端点。
解法:
第一遍dfs
或者bfs
从任意一个点出发,通常选择
1
1
1号点,找到距离
1
1
1号点最远的所有点
p
o
i
n
t
s
points
points。他们都是树的直径的端点。
然后从
p
o
i
n
t
s
points
points任选一个点rt
作为dfs
或者bfs
起点,找到距离点rt
最远的所有点,他们也是树的直径的端点。
由于第一遍遍历和第二遍遍历的点可能有重复:
如一个
3
3
3个点的树:
边
1
−
2
1-2
1−2
边
1
−
3
1-3
1−3
第二次选择
2
2
2号点为起点时,
3
3
3号点又会被选择一次,所以需要去重。
可以使用
s
e
t
set
set或者直接手动去重。
证明:
证明一下为什么只需要从
p
o
i
n
t
s
points
points中任选一个点跑第二次遍历即可。
第二次遍历的初始点为
u
u
u,遍历到的最远点为
v
v
v
所有
u
u
u的最近公共祖先为
r
o
o
t
root
root
那么
u
−
v
u-v
u−v一定是会经过
r
o
o
t
root
root的。
- 如果 v v v在 p o i n t s points points中,则相当于从 v v v开始跑可以跑到的最远处之一是 u u u,所以点数不会遗漏。
- 如果 v v v不在 p o i n t s points points中, p o i n t s points points中的其他点跑到的最远一定也有 v v v,相当于换起点跑,但是这些起点到 r o o t root root的距离都相同,而且一定会经过 r o o t root root到达 v v v,所以距离为: d i s ( r o o t , u ) + d i s ( r o o t , v ) dis(root,u)+dis(root,v) dis(root,u)+dis(root,v)
例题:
最深的根
例题代码:
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
using namespace std;
#define sz(x) (int)x.size()
const int N = 20010, M = 20010;
int n;
int h[N], e[M], ne[M], idx;
int p[N];
int ans[N], g;
int dis[N];
void add(int a, int b) {
e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
int find(int x) {
if(x != p[x]) p[x] = find(p[x]);
return p[x];
}
int rt, mx = -1;
void dfs(int u, int fa, int dep) {
if(dep > mx) mx = dep, rt = u;
dis[u] = dep;
for(int i = h[u]; ~i; i = ne[i]) {
int v = e[i];
if(v == fa) continue;
dfs(v, u, dep + 1);
}
}
void dfs2(int u, int fa, int dep) {
mx = max(mx, dep);
dis[u] = dep;
for(int i = h[u]; ~i; i = ne[i]) {
int v = e[i];
if(v == fa) continue;
dfs2(v, u, dep + 1);
}
}
int main()
{
scanf("%d", &n);
memset(h, -1, n + 1 << 2);
int cnt = n;
for(int i = 1; i <= n; ++i) p[i] = i;
for(int i = 1; i < n; ++i) {
int a, b; scanf("%d%d", &a, &b);
if(find(a) != find(b)) {
p[find(a)] = p[find(b)];
--cnt;
}
add(a, b);
add(b, a);
}
if(cnt != 1) printf("Error: %d components\n", cnt);
else {
dfs(1, -1, 0);
for(int i = 1; i <= n; ++i)
if(dis[i] == mx) ans[++g] = i;
mx = -1;
dfs2(rt, 0, 0);
for(int i = 1; i <= n; ++i)
if(dis[i] == mx) ans[++g] = i;
sort(ans + 1, ans + g + 1);
n = unique(ans + 1, ans + g + 1) - ans - 1;
for(int i = 1; i <= n; ++i) printf("%d\n", ans[i]);
}
return 0;
}