题意:题目给了一个N NN个点N−1 N-1N−1条边的树形图
现在每次出行的交通工具是鱼,而鱼对于路径长度有不同的喜好
如果总路径长度是3的倍数,那么鱼需要路径长度数量的 榴莲
如果总路径长度%3=1,那么鱼需要路径长度数量的 木瓜
如果总路径长度%3=2,那么鱼需要路径长度数量的 牛奶果
现在问,从图上的每个点i 到达除了i 之外的所有点(N∗(N−1)条路径),总共需要花费多少榴莲、木瓜、牛奶果
解析:题目要我们求任意两点间距离对答案的贡献。d[i] 表示点 i 的子节点到点 i 的距离之和, num[i] 表示 点 i 及其后代的数量, 距离和后代可以分为 3 种, d % 3 = 0, d%3 = 1, d % 3 = 2, 所以可以分为 d[i][j] , num[i][j]。
可以得到递推方程:
// w 表示 u——v 距离, v 是 u 的子节点 num[v][0]++; for(int j = 0; j <= 2; j++){ d[u][(j+w)%3] += num[v][j]*w + d[v][j]; num[u][(j+w)%3] += num[v][j]; }
通过一次dfs(1, 0)可以得到 以 1 为根节点时, 所有点到 点 1的距离之和, 但是这只是记录了点 1 和 任意 一节点之间距离对答案的贡献, 还需要求 以其他节点为根, 对答案的贡献, 还进行n - 1次dfs?? 不需要, 通过换根dp只需要再来一次dfs就可以。
具体解析可以看这个大佬的, 。别人写的特别好https://blog.csdn.net/weixin_44282912/article/details/100833858
//https://nanti.jisuanke.com/t/41403
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e4+10;
const int mod = 1e9+7;
int n;
struct Edge{
int to, next, w;
}e[N<<1];
int head[N], tot;
ll d[N][3], num[N][3], ans[3];
void addEdge(int u, int v, int w){
e[tot] = Edge{v, head[u], w};
head[u] = tot++;
}
void dfs1(int u, int fa){
for(int i = 0; i <= 2; i++){
d[u][i] = num[u][i] = 0;
}
for(int i = head[u]; i != -1; i = e[i].next){
int v = e[i].to, w = e[i].w;
if(v == fa) continue;
dfs1(v, u);
num[v][0]++;
for(int j = 0; j <= 2; j++){
d[u][(j+w)%3] += num[v][j]*w + d[v][j];
num[u][(j+w)%3] += num[v][j];
}
}
}
void pika(int u){
for(int i = 0; i <= 2; i++){
ans[i] += d[u][i];
ans[i] %= mod;
}
}
void dfs2(int u, int fa){
for(int i = head[u]; i != -1; i = e[i].next){
int v = e[i].to, w = e[i].w;
if(v == fa) continue;
//记录下原来的d,num 后面需要还原
ll d1[3], d2[3], num1[3], num2[3];
for(int j = 0; j <= 2; j++){
d1[j] = d[u][j]; d2[j] = d[v][j];
num1[j] = num[u][j]; num2[j] = num[v][j];
}
//因为现在树根是v,原来的树根u变为v的子节点
//所以原来的树根u的路径数量要减去v的路径数量(也就是dfs1里面倒过来处理一遍)
for(int j = 0; j <= 2; j++){
d[u][(j+w)%3] -= num[v][j]*w + d[v][j];
num[u][(w+j)%3] -= num[v][j];
}
num[v][0]--;
num[u][0]++;
for(int j = 0; j <= 2; j++){
d[v][(w+j)%3] += num[u][j]*w + d[u][j];
num[v][(w+j)%3] += num[u][j];
}
pika(v);
dfs2(v, u);
//还原
for(int j = 0; j <= 2; j++){
d[u][j] = d1[j]; d[v][j] = d2[j];
num[u][j] = num1[j]; num[v][j] = num2[j];
}
}
}
int main(){
while(~scanf("%d", &n)){
tot = 0;
memset(head, -1, sizeof(int)*(n+1));
for(int i = 0; i <= 2; i++) ans[i] = 0;
int u, v, w;
for (int i = 1; i < n; ++i){
scanf("%d%d%d", &u, &v, &w);
v++, u++;
addEdge(u, v, w);
addEdge(v, u, w);
}
dfs1(1, 0);
pika(1);
dfs2(1, 0);
printf("%lld %lld %lld\n", ans[0], ans[1], ans[2]);
}
return 0;
}