树的直径
Q:由n个结点组成的一棵树,求树上最长的路径(树的直径)。(路径上结点数之和)
在学会如何写代码之前,我们要先了解一下树的直径的性质。
- 1.直径的两端点一定是两个叶子节点。
- 2.距离任意点最远的点一定是直径的一个端点。
让我们来证明一下上面的两个结论。
命题1:直径的两端点一定是两个叶子节点
我们这里采用反证法,如果直径的两个端点不是叶子结点,那么必然这个节点一定会有孩子节点,那么这样的路程又可以增加一节,所以原直径并不是这棵树的直径,矛盾。
所以直径的两端点一定是叶子结点
命题2:距离任意点最远的点一定是直径的一个端点
我们这里同样采用反证法,如果距离任意点最远的点不是直径的一个端点,那么这里就有两种情况:
我们设当前直径为xy,现有任意点O和另一个点M情况1:如果O在直径xy上
∵ O M > O X 或 O M > O Y \because OM>OX或OM>OY ∵OM>OX或OM>OY
∴ d = O X + O M 或 O Y + O M \therefore d=OX+OM或OY+OM ∴d=OX+OM或OY+OM
与条件不符。情况2:如果O不在直径xy上
∵ O M > O X 或 O M > O Y \because OM>OX或OM>OY ∵OM>OX或OM>OY
∴ d = O X + O M + X Y 或 O Y + O M + X Y \therefore d=OX+OM+XY或OY+OM+XY ∴d=OX+OM+XY或OY+OM+XY
与条件不符综上,情况1,2皆不符和题意,所以矛盾,所以距离任意点最远的点一定是直径的一个端点
知道了这两个结论,我们就可以来思考怎么用程序来写了。
首先可以想到,如果我们从任意一个点出发,达到距离它距离最远的点
n
n
n,
n
n
n就一定是直径的一端,此时再从
n
n
n出发,达到距离
n
n
n最远的点
m
m
m,
m
m
m就一定也是直径的一端,所以
n
m
nm
nm就是这棵树的直径,基于此,我们就可以采用两次dfs来求得树的直径
#include<bits/stdc++.h>
using namespace std;
vector<int> mp[1000010];
int dis[1000010],st;
void dfs(int x,int fa){
for (int i=0;i<mp[x].size();i++){
if (mp[x][i]!=fa){
dis[mp[x][i]]=dis[x]+1;//距离增加
if (dis[st]<dis[mp[x][i]])st=mp[x][i];//更新最远的点
dfs(mp[x][i],x);
}
}
}
int main(){
int n;
cin>>n;
for (int i=1;i<n;i++){
int u,v;
cin>>u>>v;
mp[u].push_back(v);
mp[v].push_back(u);
}
dfs(1,0);
dis[st]=0;
dfs(st,0);//两次dfs
cout<<dis[st]+1;
return 0;
}
求所有点的最远距离
Q:给你一棵 N(N<=500000)个节点的树,求每个点到其他点的最大距离。
这个问题有一个极为朴素的做法,我们可以去遍历每一个点,找到每一个点最大距离,这样下来的时间复杂度是
O
(
n
×
(
n
+
m
)
)
O(n\times (n+m))
O(n×(n+m)),有点高,那么有没有什么办法可以减小时间复杂度的呢?
刚刚我们知道了树的直径的性质,现在我们就可来利用这些性质。我们知道,距离任意点最远的点一定是直径的一个端点,从这句话我们可以得出,从端点到任意点的距离一定是最长的,所以我们这里就可以从两个端点出发,遍历完所有的点,然后比较两条路径哪个长就可以了
#include<bits/stdc++.h>
using namespace std;
vector<int> mp[500010];
int dis1[500010],dis2[500010],st,ed;
void dfs1(int x,int fa){
for (int i=0;i<mp[x].size();i++){
if (mp[x][i]!=fa){
dis1[mp[x][i]]=dis1[x]+1;
if (dis1[st]<dis1[mp[x][i]])st=mp[x][i];
dfs1(mp[x][i],x);
}
}
}
void dfs2(int x,int fa){
for (int i=0;i<mp[x].size();i++){
if (mp[x][i]!=fa){
dis1[mp[x][i]]=dis1[x]+1;
dfs2(mp[x][i],x);
}
}
}
void dfs3(int x,int fa){
for (int i=0;i<mp[x].size();i++){
if (mp[x][i]!=fa){
dis2[mp[x][i]]=dis2[x]+1;
dfs3(mp[x][i],x);
}
}
}
int main(){
int n;
scanf("%d",&n);
for (int i=1;i<n;i++){
int u,v;
scanf("%d %d",&u,&v);
mp[u].push_back(v);
mp[v].push_back(u);
}
dfs1(1,0);
ed=st;
dis1[st]=0;
dfs1(st,0);
memset(dis1,0,sizeof(dis1));
dfs2(st,0);
dfs3(ed,0);
for (int i=1;i<=n;i++){
printf("%d\n",max(dis1[i],dis2[i]));
}
return 0;
}