题目
题解
本质上是一个计算树的直径的问题。(模板题)
树的直径有两种计算方式:
- 两次dfs。以任意一点
X
为根节点,第一次dfs找到距离X
最远的节点P
,第二次dfs找到距离节点P
最远的节点Q
,P
和Q
的距离就是树的直径。 - 树型DP。两个数组
mx[i]
,_mx[i]
分别记录当前节点i
到根节点(随意选取)的最大距离和次最大距离。假设j
为i
的子结点,更新方式:如果mx[i] <= mx[j] + cost[i][j]
,则_mx[i] = mx[i]
,mx[i] = mx[j] + cost[i][j]
;否则,如果_mx[i] < mx[j] + cost[i][j]
,则_mx[i] = mx[j] + cost[i][j]
。
代码
两次dfs
#include<bits/stdc++.h>
using namespace std;
const int N = 1e5+10;
int idx, n, u, v, w;
int h[N], ne[N<<2], e[N<<2];
int cost[N<<2]; // cost[i] 表示第i条边的权重
int d[N]; // d[i]表示从节点i到根节点的距离(至于哪个是根节点,不同情况不一样)
int p, ans; // p表示暂时找到中转点,ans当前搜索情况下的最长距离
void add (int a, int b, int w) { // 邻接表添加边
cost[idx] = w, e[idx] = b, ne[idx] = h[a], h[a] = idx ++;
}
void dfs (int x, int fa) {
if (ans < d[x]) { // x节点到根节点的距离比最大距离大,则更新最大距离
ans = d[x];
p = x; // 记录距离根节点最远的节点
}
for (int i = h[x];~i;i = ne[i]) { // 注意 ~
int j = e[i]; // 根据边的编号得到该边对应的终点的编号
if (j == fa) continue; // 如果回头了,则continue
d[j] = d[x] + cost[i]; // 更新子结点 j 到根节点的距离(因为路径唯一,所以可以这样更新)
dfs (j, x); // 递归遍历 j 节点的子结点
}
}
void find (int x) {
ans = 0; // 初始化 ans 为 0
d[x] = 0; // 根节点到根节点的距离记为 0
dfs (x, 0);
}
int main()
{
memset (h, -1, sizeof h); // 初始化
cin >> n;
for (int i = 1;i < n;i ++) {
cin >> u >> v >> w;
add (u, v, w);
add (v, u, w);
}
find (1); // 随便选一个点作为根节点,找距离其最远的节点
find (p); // 将上面找到的 p 作为根节点,找距离 p 最远的节点,该距离为树的直径
cout << ans * 10 + ans * (ans + 1) / 2 << endl; // 公式输出
return 0;
}
树型 DP
// 计算树的直径,树型DP
#include<bits/stdc++.h>
using namespace std;
const int N = 1e5+10;
int n, m, u, v, w, idx;
int e[N<<2], ne[N<<2], h[N], cost[N<<2];
int ans, _mx[N], mx[N];
void add (int a, int b, int w) {
cost[idx] = w, e[idx] = b, ne[idx] = h[a], h[a] = idx ++;
}
void dp (int x, int fa) {
for (int i = h[x];~i;i = ne[i]) {
int j = e[i];
if (j == fa) continue;
dp (j, x);
if (mx[x] <= mx[j] + cost[i])
_mx[x] = mx[x],
mx[x] = mx[j] + cost[i];
else if (_mx[x] < mx[j] + cost[i])
_mx[x] = mx[j] + cost[i];
ans = max (ans, _mx[x] + mx[x]);
}
}
int main()
{
memset (h, -1, sizeof h);
cin >> m;
for (int i = 1;i < m;i ++) {
cin >> u >> v >> w;
add (u, v, w);
add (v, u, w);
}
dp (1, 0);
cout << ans * 10 + ans * (ans + 1) / 2 << endl;
return 0;
}