这题最后一小时绝杀,一发过,爽啊
Fish eating fruit
题意
给一个树, n n n个顶点,边上有权值,然后两两之间路径有 n ∗ ( n − 1 ) / 2 n*(n-1)/2 n∗(n−1)/2种,假如其中一条长度为 s s s,如果 s m o d 3 = = 0 s\ mod\ 3==0 s mod 3==0, c 1 = c 1 + s c_1=c_1+s c1=c1+s,如果 s m o d 3 = = 1 s\ mod\ 3==1 s mod 3==1, c 2 = c 2 + s c_2=c_2+s c2=c2+s,如果 s m o d 3 = = 2 s\ mod\ 3==2 s mod 3==2, c 3 = c 3 + s c_3=c_3+s c3=c3+s,遍历所有路径,算出最后的 c 1 c_1 c1, c 2 c_2 c2, c 3 c_3 c3。
思路
首先将树以一个顶点拎起来。
然后定义
c
n
t
[
i
]
[
j
]
cnt[i][j]
cnt[i][j]:顶点为
i
i
i,到子节点的长度%3==
j
j
j的节点个数。
定义
d
p
[
i
]
[
j
]
dp[i][j]
dp[i][j]:顶点为
i
i
i,到子节点的长度%3==
j
j
j的长度和。
然后根据儿子的
c
n
t
,
儿
子
的
d
p
,
cnt,儿子的dp,
cnt,儿子的dp,还有到儿子的长度,转移
然后可以得出一个点的贡献
然后考虑换根,每次需要把根节点翻下去,再把儿子节点翻上来,考虑
c
n
t
和
d
p
cnt和dp
cnt和dp的变化,回溯递归
/* Author : Rshs
* Data : 2019-09-14-14.46
*/
#include<bits/stdc++.h>
using namespace std;
#define FI first
#define SE second
#define LL long long
#define MP make_pair
#define PII pair<int,int>
#define SZ(a) (int)a.size()
const double pai = acos(-1);
const double eps = 1e-10;
const LL mod = 1e9+7;
const int MX = 1e6+5;
struct no{
int v,d;
};
vector<no>g[MX];
LL dp[MX][3];
LL cnt[MX][3];
LL ans[3];
void dfs(int now,int fa){
for(auto i:g[now]){
if(i.v==fa) continue;
dfs(i.v,now);
cnt[now][i.d%3]++;
dp[now][i.d%3]=(dp[now][i.d%3]+(LL)i.d)%mod;
for(int j=0;j<3;j++){
int z=(j+i.d)%3;
cnt[now][z]+=cnt[i.v][j];
}
for(int j=0;j<3;j++){
int z=(j+i.d)%3;
dp[now][z]=(dp[now][z]+dp[i.v][j]+cnt[i.v][j]*(LL)i.d%mod)%mod;
}
}
}
void DFS(int now,int fa){
for(int i=0;i<3;i++) ans[i]=(ans[i]+dp[now][i])%mod;
for(auto i:g[now]){
if(i.v==fa)continue;
LL aa[3],bb[3],cc[3],dd[3];
for(int j=0;j<3;j++) aa[j]=cnt[now][j],bb[j]=dp[now][j],cc[j]=cnt[i.v][j],dd[j]=dp[i.v][j];
dp[now][i.d%3]=((dp[now][i.d%3]-(LL)i.d)%mod+mod)%mod;
cnt[now][i.d%3]--;
for(int j=0;j<3;j++){
int z=(j+i.d)%3;
cnt[now][z]-=cnt[i.v][j];
}
for(int j=0;j<3;j++){
int z=(i.d+j)%3;
dp[now][z]=(dp[now][z]-dp[i.v][j]-cnt[i.v][j]*(LL)i.d%mod)%mod;
dp[now][z]+=mod;dp[now][z]%=mod;
}
cnt[i.v][i.d%3]++;
dp[i.v][i.d%3]=(dp[i.v][i.d%3]+(LL)i.d)%mod;
for(int j=0;j<3;j++){
int z=(j+i.d)%3;
cnt[i.v][z]+=cnt[now][j];
}
for(int j=0;j<3;j++){
int z=(j+i.d)%3;
dp[i.v][z]=(dp[i.v][z]+dp[now][j]+cnt[now][j]*(LL)i.d%mod)%mod;
}
DFS(i.v,now);
for(int j=0;j<3;j++) cnt[now][j]=aa[j],dp[now][j]=bb[j],cnt[i.v][j]=cc[j],dp[i.v][j]=dd[j];
}
}
int main(){
int n;
while(cin>>n){
for(int i=0;i<=n;i++) g[i].clear();
for(int i=0;i<=n;i++)for(int j=0;j<3;j++) dp[i][j]=cnt[i][j]=0,ans[j]=0;
for(int i=1;i<n;i++){
int sa,sb,sc;scanf("%d%d%d",&sa,&sb,&sc);
g[sa].push_back(no{sb,sc});
g[sb].push_back(no{sa,sc});
}
dfs(0,-1);
DFS(0,-1);
cout<<ans[0]<<' ' <<ans[1]<< ' '<<ans[2]<<'\n';
}
return 0;
}