题意
就是对树上所有有向路径(即有序点对 $(u, v)$,$u \ne v$)按路径长度 $\bmod 3$ 的余数($0$、$1$、$2$)分成三类,分别统计每一类路径的长度总和(即模 $3$ 意义下的总权值,输出时对 $10^9 + 7$ 取模)。
题解
显然是个换根 $dp$。需要维护的信息不仅包括模 $3$ 意义下的总权值(路径长度之和),还应包括模 $3$ 意义下的路径条数,转移起来有点麻烦,建议写树分治。
代码
#include <bits/stdc++.h>
using namespace std;
// Capacity of every per-vertex array; vertex ids are shifted by one on input
// (u++, v++ in main), so ids run from 1 upward.
const int maxn = 200000 + 10;
// Modulus applied to every accumulated path-length sum.
const long long mod = 1000000000LL + 7;
// Adjacency list: vec[u] holds (neighbour, edge weight) pairs.
std::vector<std::pair<int, int> > vec[maxn];
// Number of vertices in the current test case.
int n;
// dp[v][d][r]  : sum (mod `mod`) of the lengths of paths that start at v and
//                whose length % 3 == r;
//                d == 0 -> paths going down into v's subtree,
//                d == 1 -> paths leaving v through its parent.
// cnt[v][d][r] : number of such paths.
// ans[r]       : running total over all directed paths with length % 3 == r.
long long dp[maxn][2][3], cnt[maxn][2][3], ans[3];
void dfs1 ( int cur, int fa)
{
for ( int i= 0 ; i< vec[ cur] . size ( ) ; i++ ) {
int to= vec[ cur] [ i] . first, w= vec[ cur] [ i] . second;
if ( to!= fa) dfs1 ( to, cur) ;
}
for ( int i= 0 ; i< vec[ cur] . size ( ) ; i++ ) {
int to= vec[ cur] [ i] . first, w= vec[ cur] [ i] . second;
if ( to!= fa) {
cnt[ cur] [ 0 ] [ w% 3 ] ++ ; dp[ cur] [ 0 ] [ w% 3 ] = ( dp[ cur] [ 0 ] [ w% 3 ] + w) % mod;
for ( int j= 0 ; j<= 2 ; j++ ) {
cnt[ cur] [ 0 ] [ ( j+ w) % 3 ] + = cnt[ to] [ 0 ] [ j] ;
dp[ cur] [ 0 ] [ ( j+ w) % 3 ] = ( dp[ cur] [ 0 ] [ ( j+ w) % 3 ] + dp[ to] [ 0 ] [ j] + 1LL * cnt[ to] [ 0 ] [ j] * w% mod) % mod;
}
}
}
}
void dfs2 ( int cur, int fa)
{
for ( int i= 0 ; i<= 2 ; i++ ) ans[ i] = ( ans[ i] + dp[ cur] [ 0 ] [ i] + dp[ cur] [ 1 ] [ i] ) % mod;
long long dp_[ 3 ] , cnt_[ 3 ] ;
for ( int i= 0 ; i<= 2 ; i++ ) dp_[ i] = dp[ cur] [ 1 ] [ i] + dp[ cur] [ 0 ] [ i] , cnt_[ i] = cnt[ cur] [ 1 ] [ i] + cnt[ cur] [ 0 ] [ i] ;
for ( int i= 0 ; i< vec[ cur] . size ( ) ; i++ ) {
int to= vec[ cur] [ i] . first, w= vec[ cur] [ i] . second;
if ( to!= fa) {
cnt_[ w% 3 ] -- ; dp_[ w% 3 ] = ( ( dp_[ w% 3 ] - w) % mod+ mod) % mod;
for ( int i= 0 ; i<= 2 ; i++ ) {
cnt_[ ( i+ w) % 3 ] - = cnt[ to] [ 0 ] [ i] ;
dp_[ ( i+ w) % 3 ] = ( ( dp_[ ( i+ w) % 3 ] - dp[ to] [ 0 ] [ i] - 1LL * cnt[ to] [ 0 ] [ i] * w% mod) % mod+ mod) % mod;
}
cnt[ to] [ 1 ] [ w% 3 ] ++ ; dp[ to] [ 1 ] [ w% 3 ] = ( dp[ to] [ 1 ] [ w% 3 ] + w) % mod;
for ( int i= 0 ; i<= 2 ; i++ ) {
dp[ to] [ 1 ] [ ( i+ w) % 3 ] = ( dp[ to] [ 1 ] [ ( i+ w) % 3 ] + dp_[ i] + 1LL * cnt_[ i] * w% mod) % mod;
cnt[ to] [ 1 ] [ ( i+ w) % 3 ] + = cnt_[ i] ;
}
cnt_[ w% 3 ] ++ ; dp_[ w% 3 ] = ( dp_[ w% 3 ] + w) % mod;
for ( int i= 0 ; i<= 2 ; i++ ) {
cnt_[ ( i+ w) % 3 ] + = cnt[ to] [ 0 ] [ i] ;
dp_[ ( i+ w) % 3 ] = ( ( dp_[ ( i+ w) % 3 ] + dp[ to] [ 0 ] [ i] + 1LL * cnt[ to] [ 0 ] [ i] * w% mod) % mod+ mod) % mod;
}
}
}
for ( int i= 0 ; i< vec[ cur] . size ( ) ; i++ ) {
int to= vec[ cur] [ i] . first, w= vec[ cur] [ i] . second;
if ( to!= fa) dfs2 ( to, cur) ;
}
}
// Driver: processes test cases until the input runs out.
//
// Each case is n followed by n - 1 lines "u v w" (an undirected edge of
// weight w). Vertex ids are incremented on input so the tree is rooted at
// vertex 1. Prints ans[0] ans[1] ans[2]: the total length of all directed
// paths whose length % 3 is 0, 1 and 2 respectively, modulo `mod`.
int main()
{
    // Compare against 1 rather than writing `~scanf(...)`: `~x` is false only
    // for x == EOF (-1), so a malformed token (scanf returns 0) would spin forever.
    while (scanf("%d", &n) == 1) {
        for (int i = 1; i < n; i++) {
            int u, v, w;
            // Truncated input: stop instead of building the tree from indeterminate values.
            if (scanf("%d %d %d", &u, &v, &w) != 3) return 0;
            u++;
            v++;
            vec[u].push_back(make_pair(v, w));
            vec[v].push_back(make_pair(u, w));
        }
        dfs1(1, 0);
        dfs2(1, 0);
        printf("%lld %lld %lld\n", ans[0], ans[1], ans[2]);

        // Reset all global state touched by this case before the next one.
        for (int i = 1; i <= n; i++) {
            vec[i].clear();
            for (int j = 0; j <= 1; j++)
                for (int k = 0; k <= 2; k++) dp[i][j][k] = cnt[i][j][k] = 0;
        }
        for (int i = 0; i <= 2; i++) ans[i] = 0;
    }
    return 0;
}