834. 树中距离之和
给定一个无向、连通的树。树中有 N 个标记为 0…N-1 的节点以及 N-1 条边 。
第 i 条边连接节点 edges[i][0] 和 edges[i][1] 。
返回一个表示节点 i 与其他所有节点距离之和的列表 ans。
思路
如图,节点2到其他所有节点距离之和可以分解为两部分:
- 节点2内部子树节点到节点2的距离之和。
- 节点2外部节点到节点2的距离之和。
定义dist[i]表示第i个节点内部子树节点到节点i的距离之和,那么dist[2]怎么求呢?
- dist[2] += dist[3] + 2,因为节点3子树有两个节点,需要加上2次红色路径。
- dist[2] += dist[4] + 1,节点4只有一个节点。
因此,还需要定义nodeNum[i]表示节点i所在子树的节点数量,初始化均为1。思路基本形成了,在递归树的同时计算节点数量,同时计算dist,那么递归的base case是什么呢?很容易想到,叶子节点就是base case,此时dist[i] = 0,nodeNum[i] = 1,其只有一条路径连向父节点。只要从叶子节点开始处理,更新dist和nodeNum即可。先处理叶子节点,再处理根节点的递归树方式是什么呢?很明显就是后序遍历,因此代码如下:
后序遍历处理内部子树
void postOrder(int u, int v) {
//graph是类似邻接表,存储节点的邻居信息
for(auto& neighbor : graph[u]) {
//当前节点的邻居为父节点,跳过
//邻居只有父节点,base case结束递归
if(neighbor == v) {
continue;
}
//后序遍历
postOrder(neighbor, u);
//计算dist
dist[u] += dist[neighbor] + nodeNum[neighbor];
//更新nodeNum
nodeNum[u] += nodeNum[neighbor];
}
}
处理好节点2内部子树节点到节点2的距离之和之后,也就是计算好dist之后,还需要计算节点2外部节点到节点2的距离之和。
如图,节点0的dist[0]是正确的,因为其是根节点,可不可以从这个正确的dist[0]推出其他节点正确的dist[i]呢?如图,dist[2]是存储了节点2内部子树节点到节点2的距离,dist[0]是存储了节点0内部子树节点到节点2的距离。
- 节点2外部子树的全部节点到节点0的距离 + 节点2外部子树全部节点数 = 节点2外部子树的全部节点到节点2的距离
- 节点2内部子树的全部节点带节点0的距离 - 节点2内部字数全部节点数 = 节点2内部子树的全部节点到节点2的距离
- 两者相加,其实用公式来表示就是dist[0] + N - nodeNum[2] - nodeNum[2] = dist[2]
因而可以从根节点不断更新子节点的dist值,从根节点更新子节点,使用的是先序遍历。
先序遍历处理外部子树
void preOrder(int u, int v) {
for(auto& neighbor : graph[u]) {
if(neighbor == v) {
continue;
}
dist[neighbor] = dist[u] - nodeNum[neighbor] + nodeNum.size() - nodeNum[neighbor];
preOrder(neighbor, u);
}
}
综上,只需要后序遍历+先序遍历就可以解决问题。
全部代码
class Solution {
public:
vector<int> dist, nodeNum;
vector<vector<int>> graph;
void postOrder(int u, int v) {
for(auto& neighbor : graph[u]) {
if(neighbor == v) {
continue;
}
postOrder(neighbor, u);
dist[u] += dist[neighbor] + nodeNum[neighbor];
nodeNum[u] += nodeNum[neighbor];
}
}
void preOrder(int u, int v) {
for(auto& neighbor : graph[u]) {
if(neighbor == v) {
continue;
}
dist[neighbor] = dist[u] - nodeNum[neighbor] + nodeNum.size() - nodeNum[neighbor];
preOrder(neighbor, u);
}
}
vector<int> sumOfDistancesInTree(int N, vector<vector<int>>& edges) {
//距离初始化为0
dist.resize(N, 0);
//节点数初始化为1
nodeNum.resize(N, 1);
graph.resize(N);
//构建图邻接表
for(auto& edge : edges) {
int u = edge[0], v = edge[1];
graph[u].push_back(v);
graph[v].push_back(u);
}
//后序遍历
postOrder(0, -1);
//前序遍历
preOrder(0, -1);
return dist;
}
};