树的深度优先遍历框架
时间复杂度为:O(n + m)
(n
为点数,m
为边数)
使用一个bool
数组记录每个节点的遍历情况,防止重复遍历
void dfs(int u)
{
st[u]=true; // 标记一下,记录为已经被搜索过了,下面进行搜索过程
for(int i=h[u]; ~i; i=ne[i])
{
int j = e[i];
if(!st[j]) dfs(j);
}
}
题意:
重心定义:重心是指树中的一个结点,如果将这个点删除后,剩余各个连通块中点数的最大值最小
那么这个节点被称为树的重心。
给定一颗树,树中包含 n
个结点(编号 1~n
)和 n - 1
条无向边。
请你找到树的重心,并输出将重心删除后,剩余各个连通块中点数的最大值。
举个例子方便理解,我们就拿题目给出的样例构建出一棵树:
我们依次枚举将各点删除后,剩余每个连通块中节点数的最大值
比如我们将 1
号根节点删掉,剩下了 3
个连通块,其中连通块中节点数最大值为 4
(包含4、3、6、9六个节点)
如果我们将 2
号节点删去,同样剩余 3
个连通块,其中连通块中节点数最大值为 6
(包含1、4、7、3、6、9
六个节点)
同样的道理,我们将 4
号节点删去,也剩余 3
个连通块,其中连通块中节点数最大值为 5
(包含1、2、7、8、5
五个节点)
其他的节点这里就不作分析了,
我们的目标是求出并输出“最小的最大值”,经过树中所有点的枚举,我们可以得出最优解就是 4
,即 将树中的重心删除后,剩余各连通块中点数最大值为 4
思路:
只要对于树中每个节点都能求出 “把这个点删除后,剩余各连通块点数的最大值”,之后在所有值中求得最小值,就是我们目标的答案。
那么如何快速求得将每个点删去后的剩余各连通块点数的最大值呢?
答案:运用树的深度优先遍历。
如上图,当计算 4
这棵子树的大小时,先递归处理完 3、6
两棵子树,
“4
往下遍历的过程当中,可以统计出来 3、6
两棵子树”,
则 3
子树大小 + 6
子树大小 + 4
号节点本身 即为 4
这棵子树总大小
删去 4
之后,剩余 3
个连通块:{3,9}、{6}、{1,2,7,8,5}
,其中第 1、2
个连通块我们称为“子节点部分”,第 3
个连通块我们称为“父节点及以上部分”。
对于“子节点部分”我们可以通过dfs
在回溯的时候返回,我们设它们的总和为size(sons)
,
对于“父节点及以上部分”可以通过 总点数 n
- 4
号子树大小 算出,其中 4
号子树大小 = size(sons) + 1
(4
号点本身也算 1
个)
总结:深度优先遍历的过程当中,递归每一个节点的时候,我们都可以计算出来:将这个点删去之后,其余各个连通块中点的数量(大小),这是个很有用的性质
这份代码能够计算出树上每个节点对应的子树大小,并将其存入sz
数组并输出:dfs深度优先遍历 求树上各节点代表子树大小
我们每次找到剩余连通块中点数的最大值,并和答案取 min
,不断迭代更新,就可以得到最终的答案 ans
。
时间复杂度:
O(n + m)
(n
为点数,m
为边数)
代码:
#include<bits/stdc++.h>
using namespace std;
const int N = 1e5+10;
int n;
bool st[N];
int h[N], e[N<<1], ne[N<<1], idx, sz[N];
int ans; //设置一个全局的答案,存储“最小的最大值”
void add(int a, int b)
{
e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
int dfs(int u) //返回以 u 为子树的大小
{
int res = 0; //存储将该点删去之后,剩余各连通块大小的最大值
sz[u] = 1; //当前点算 1 个点,所以从 1 开始
st[u] = true;
for(int i=h[u]; ~i; i=ne[i])
{
int j = e[i];
if(!st[j])
{
int s = dfs(j); //子树 j 大小
res = max(res, s); //子树 j 是个连通块,将它的大小与 res 取 max
sz[u] += s; //子树 j 是子树 u 的一部分,因此 sz[u] 要加上子树 j 大小
}
}
res = max(res, n-sz[u]); //n-sz[u] 即上文提到的节点 u “父节点及以上部分” 的大小,也是个连通块,用其大小更新 res
ans = min(ans, res); //更新全局答案
return sz[u];
}
int main()
{
cin>>n;
ans = n;
memset(h, -1, sizeof h);
for(int i=0; i<n-1; ++i)
{
int a, b;
scanf("%d%d", &a, &b);
add(a, b), add(b, a); //无向边
}
dfs(1);
printf("%d\n", ans);
return 0;
}