【题目】*834. 树中距离之和
【解题思路1】树形动态规划
首先考虑一个节点的情况,即每次题目指定一棵树,以 root 为根,询问节点 root 与其他所有节点的距离之和。
dp数组的含义: 定义 dp[u] 表示以 u 为根的子树,它的所有子节点到它的距离之和,同时定义 sz[u] 表示以 u 为根的子树的节点数量
动态转移方程: d p [ u ] = ∑ v ∈ s o n [ u ] d p [ v ] + s z [ v ] dp[u]=\sum_{v\in son[u]} dp[v] + sz[v] dp[u]=∑v∈son[u]dp[v]+sz[v]
其中son[u] 表示 u 的所有后代节点集合。转移方程表示的含义就是考虑每个后代节点 v,已知 v 的所有子节点到它的距离之和为 dp[v],那么这些节点到 u 的距离之和还要考虑 u→v 这条边的贡献。考虑这条边长度为 1,一共有 sz[v] 个节点到节点 u 的距离会包含这条边,因此贡献即为 1×sz[v]=sz[v]。遍历整棵树,从叶子节点开始自底向上递推到根节点 root 即能得出最后的答案为 dp[root]。
回到本题中,题目要求的其实是上题的扩展,即要求求出每个节点为根节点的时候,它与其他所有节点的距离之和。暴力的角度我们可以考虑对每个节点都做一次如上的树形动态规划,这样时间复杂度即为 O ( N 2 ) O(N^2) O(N2),那么有没有更优雅的方法呢?
经过一次树形动态规划后其实获得了在 u 为根的树中,每个节点为根的子树的答案 dp,我们可以利用这些已有信息来优化时间复杂
假设 u 的某个后代节点为 v,如果要算 v 的答案,本来我们要以 v 为根来进行一次树形动态规划。但是利用已有的信息,我们可以考虑树的形态做一次改变,让 v 换到根的位置,u 变为其孩子节点,同时维护原有的 dp 信息。在这一次的转变中,观察到除了 u 和 v 的 dp 值,其他节点的 dp 值都不会改变,因此只要更新 dp[u] 和 dp[v] 的值即可。
那么我们来看 v 换到根的位置的时候怎么利用已有信息求出 dp[u] 和 dp[v] 的值。重新回顾第一次树形动态规划的转移方程,我们可以知道当 u 变为 v 的孩子的时候 v 不在 u 的后代集合 son[u] 中了,因此此时 dp[u] 需要减去 v 的贡献,即
d p [ u ] = d p [ u ] − ( d p [ v ] + s z [ v ] ) dp[u]=dp[u]-(dp[v]+sz[v]) dp[u]=dp[u]−(dp[v]+sz[v])
同时 sz[u] 也要相应减去 sz[v]。
而 v 的后代节点集合中多出了 u,因此 dp[v] 的值要由 u 更新上来,即
d p [ v ] = d p [ v ] + ( d p [ u ] + s z [ u ] ) dp[v]=dp[v]+(dp[u]+sz[u]) dp[v]=dp[v]+(dp[u]+sz[u])
同时 sz[v] 也要相应加上 sz[u]。
至此我们完成了一次「换根」操作,在 O(1) 的时间内维护了 \textit{dp}dp 的信息,且此时的树结构以 v 为根。那么接下来我们不断地进行换根的操作,即能在 O(N) 的时间内求出每个节点为根的答案,实现了时间复杂度的优化。
class Solution {
int[] ans;
int[] sz;
int[] dp;
List<List<Integer>> graph;
public int[] sumOfDistancesInTree(int N, int[][] edges) {
ans = new int[N];
sz = new int[N];
dp = new int[N];
graph = new ArrayList<List<Integer>>();
for (int i = 0; i < N; ++i) {
graph.add(new ArrayList<Integer>());
}
for (int[] edge: edges) {
int u = edge[0], v = edge[1];
graph.get(u).add(v);
graph.get(v).add(u);
}
dfs(0, -1);
dfs2(0, -1);
return ans;
}
public void dfs(int u, int f) {
sz[u] = 1;
dp[u] = 0;
for (int v: graph.get(u)) {
if (v == f) {
continue;
}
dfs(v, u);
dp[u] += dp[v] + sz[v];
sz[u] += sz[v];
}
}
public void dfs2(int u, int f) {
ans[u] = dp[u];
for (int v: graph.get(u)) {
if (v == f) {
continue;
}
int pu = dp[u], pv = dp[v];
int su = sz[u], sv = sz[v];
dp[u] -= dp[v] + sz[v];
sz[u] -= sz[v];
dp[v] += dp[u] + sz[u];
sz[v] += sz[u];
dfs2(v, u);
dp[u] = pu;
dp[v] = pv;
sz[u] = su;
sz[v] = sv;
}
}
}