Description
Solution
考虑最后的路径是什么样子,首先每一条边最多经过四次,即来回两次,因为我们可以通过递归解决一条边下面的子树,然后再根据这个点当前的奇偶性,考虑经过这条边两次还是四次(即横跳一次还是两次)。 有了这个简单的思路,我们就可以直接树形DP了,一条路径把它拆分到每一条边上计算,记录
f
[
x
]
[
0
/
1
/
2
]
[
0
/
1
]
f[x][0/1/2][0/1]
f [ x ] [ 0 / 1 / 2 ] [ 0 / 1 ] 表示
x
x
x 的子树内,有
0
/
1
/
2
0/1/2
0 / 1 / 2 个路径端点,
x
x
x 的灯是
0
/
1
0/1
0 / 1 ,形成了一个以
x
x
x 开头,以
x
x
x 结尾的操作序列,直接把它跟父亲的操作序列相接,根据此前
x
x
x 灯的状态决定这条边横跳一次还是横跳两次,转移一下即可。
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#define maxn 500005
using namespace std;
int n, i, j, k, a[ maxn] ;
int em, e[ maxn* 2 ] , nx[ maxn* 2 ] , ls[ maxn] ;
int f[ maxn] [ 3 ] [ 2 ] , g[ 3 ] [ 2 ] , inf;
void read ( int & x) {
x= 0 ; char ch= getchar ( ) ;
for ( ; ch< '0' || ch> '9' ; ch= getchar ( ) ) ;
for ( ; ch>= '0' && ch<= '9' ; ch= getchar ( ) ) x= x* 10 + ch- '0' ;
}
void insert ( int x, int y) {
em++ ; e[ em] = y; nx[ em] = ls[ x] ; ls[ x] = em;
em++ ; e[ em] = x; nx[ em] = ls[ y] ; ls[ y] = em;
}
int sz[ maxn] , cnt;
void Min ( int & a, int b) { a= ( a< b) ? a: b; }
void dfs ( int x, int p) {
sz[ x] = a[ x] ^ 1 ;
f[ x] [ 0 ] [ a[ x] ^ 1 ] = f[ x] [ 1 ] [ a[ x] ^ 1 ] = f[ x] [ 2 ] [ a[ x] ^ 1 ] = 1 ;
for ( int i= ls[ x] ; i; i= nx[ i] ) if ( e[ i] != p) {
int y= e[ i] ; dfs ( y, x) , sz[ x] + = sz[ y] ;
if ( sz[ y] ) {
memset ( g, 127 , sizeof ( g) ) ;
for ( int j= 0 ; j< 2 ; j++ ) for ( int k= 0 ; k< 2 ; k++ ) for ( int t1= 0 ; t1< 3 ; t1++ ) for ( int t2= 0 ; t1+ t2< 3 ; t2++ )
if ( f[ x] [ t1] [ j] < 1e9 && f[ y] [ t2] [ k] < 1e9 ) {
if ( t2== 1 ) {
if ( k== 1 )
Min ( g[ t1+ t2] [ j] , f[ x] [ t1] [ j] + f[ y] [ t2] [ k] ) ;
else Min ( g[ t1+ t2] [ j^ 1 ] , f[ x] [ t1] [ j] + f[ y] [ t2] [ k] + 2 ) ;
} else
if ( t2== 2 )
Min ( g[ t1+ t2] [ j^ k] , f[ x] [ t1] [ j] + f[ y] [ t2] [ k] + 3 - ( k^ 1 ) * 2 ) ;
else
Min ( g[ t1+ t2] [ j^ k] , f[ x] [ t1] [ j] + f[ y] [ t2] [ k] + 3 - k* 2 ) ;
}
memcpy ( f[ x] , g, sizeof ( g) ) ;
}
}
if ( sz[ x] == cnt) {
printf ( "%d\n" , f[ x] [ 2 ] [ 1 ] ) ;
exit ( 0 ) ;
}
}
int main ( ) {
read ( n) ; char ch= getchar ( ) ;
while ( ch!= '0' && ch!= '1' ) ch= getchar ( ) ;
for ( i= 1 ; i<= n; i++ ) a[ i] = ch- '0' , ch= getchar ( ) , cnt+ = a[ i] ^ 1 ;
for ( i= 1 ; i< n; i++ ) read ( j) , read ( k) , insert ( j, k) ;
memset ( f, 127 , sizeof ( f) ) , inf= f[ 0 ] [ 0 ] [ 0 ] ;
dfs ( 1 , 0 ) ;
}