Description
给你一个
n
n
n 个点的树,你可以选择一个点在上面随机游走,每次等概率随机跳到一个距离不超过2的点(包括自己)。 现在给出
m
m
m 个标记点,求每一个点跳到任意一个标记点的期望步数。
n
,
m
≤
1
e
5
n,m\le1e5
n , m ≤ 1 e 5
Solution
考虑从叶子往上面推,那么一个点的期望
E
(
x
)
E(x)
E ( x ) 可以表示成
s
u
m
[
f
a
x
]
,
E
(
f
a
x
)
,
E
(
f
a
f
a
x
)
sum[fa_x],E(fa_x),E(fa_{fa_x})
s u m [ f a x ] , E ( f a x ) , E ( f a f a x ) 的和,其中
s
u
m
[
x
]
=
∑
E
(
s
o
n
x
)
sum[x]=\sum E(son_x)
s u m [ x ] = ∑ E ( s o n x ) 考虑一个点,要解出所有儿子的
E
(
x
)
E(x)
E ( x ) ,而这些
E
(
x
)
E(x)
E ( x ) 还互相有关。 实际上我们可以把所有
E
(
x
)
=
.
.
.
E(x)=...
E ( x ) = . . . 的方程加在一起,这样左边就是
s
u
m
sum
s u m 了,就可以把
s
u
m
sum
s u m 解出来了,这样就可以表示为
E
(
f
a
x
)
,
E
(
f
a
f
a
x
)
E(fa_x),E(fa_{fa_x})
E ( f a x ) , E ( f a f a x ) 的和了。 最后再从根节点推下来即可。
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#define maxn 100005
#define ll long long
#define mo 998244353
using namespace std;
int n, m, i, j, k, bz[ maxn] , du[ maxn] , cnt[ maxn] ;
int em, e[ maxn* 2 ] , nx[ maxn* 2 ] , ls[ maxn] , fa[ maxn] ;
ll f[ maxn] [ 3 ] , g[ maxn] , inv[ maxn] ;
ll ksm ( ll x, ll y) {
ll s= 1 ;
for ( ; y; y/ = 2 , x= x* x% mo) if ( y& 1 )
s= s* x% mo;
return s;
}
void insert ( int x, int y) {
du[ x] ++ , du[ y] ++ ;
em++ ; e[ em] = y; nx[ em] = ls[ x] ; ls[ x] = em;
em++ ; e[ em] = x; nx[ em] = ls[ y] ; ls[ y] = em;
}
ll s[ 4 ] ;
void dfs ( int x, int p) {
fa[ x] = p;
if ( ! bz[ x] ) {
f[ x] [ 0 ] = inv[ cnt[ x] - 1 ] * cnt[ x] % mo;
if ( fa[ x] ) f[ x] [ 1 ] = inv[ cnt[ x] - 1 ] , g[ x] = inv[ cnt[ x] - 1 ] ;
if ( fa[ fa[ x] ] ) f[ x] [ 2 ] = inv[ cnt[ x] - 1 ] ;
}
for ( int i= ls[ x] ; i; i= nx[ i] ) if ( e[ i] != p) dfs ( e[ i] , x) ;
s[ 0 ] = s[ 1 ] = s[ 2 ] = 0 ; ll psum= 0 ;
for ( int i= ls[ x] ; i; i= nx[ i] ) if ( e[ i] != p) {
ll tmp= ksm ( g[ e[ i] ] + 1 , mo- 2 ) ;
( f[ e[ i] ] [ 0 ] * = tmp) % = mo;
( f[ e[ i] ] [ 1 ] * = tmp) % = mo;
( f[ e[ i] ] [ 2 ] * = tmp) % = mo;
( g[ e[ i] ] * = tmp) % = mo;
( s[ 0 ] + = f[ e[ i] ] [ 0 ] ) % = mo;
( s[ 1 ] + = f[ e[ i] ] [ 1 ] ) % = mo;
( s[ 2 ] + = f[ e[ i] ] [ 2 ] ) % = mo;
( psum+ = g[ e[ i] ] ) % = mo;
}
ll Inv= ksm ( mo+ 1 - psum, mo- 2 ) ;
s[ 0 ] = s[ 0 ] * Inv% mo, s[ 1 ] = s[ 1 ] * Inv% mo, s[ 2 ] = s[ 2 ] * Inv% mo;
for ( int i= ls[ x] ; i; i= nx[ i] ) if ( e[ i] != p) {
( f[ e[ i] ] [ 0 ] + = s[ 0 ] * g[ e[ i] ] ) % = mo;
( f[ e[ i] ] [ 1 ] + = s[ 1 ] * g[ e[ i] ] ) % = mo;
( f[ e[ i] ] [ 2 ] + = s[ 2 ] * g[ e[ i] ] ) % = mo;
}
if ( ! bz[ x] ) {
s[ 0 ] = s[ 1 ] = s[ 2 ] = s[ 3 ] = 0 ;
for ( int i= ls[ x] ; i; i= nx[ i] ) if ( e[ i] != p) {
int y= e[ i] ; ll sumy= 0 ;
for ( int j= ls[ y] ; j; j= nx[ j] ) if ( e[ j] != x) {
int z= e[ j] ;
( s[ 0 ] + = f[ z] [ 0 ] ) % = mo;
( sumy+ = f[ z] [ 1 ] ) % = mo;
( s[ 1 ] + = f[ z] [ 2 ] ) % = mo;
}
sumy++ ;
( s[ 0 ] + = f[ y] [ 0 ] * sumy) % = mo;
( s[ 1 ] + = f[ y] [ 1 ] * sumy) % = mo;
( s[ 2 ] + = f[ y] [ 2 ] * sumy) % = mo;
}
s[ 0 ] = s[ 0 ] * inv[ cnt[ x] - 1 ] % mo;
s[ 1 ] = s[ 1 ] * inv[ cnt[ x] - 1 ] % mo;
s[ 2 ] = s[ 2 ] * inv[ cnt[ x] - 1 ] % mo;
( f[ x] [ 0 ] + = s[ 0 ] ) % = mo;
( f[ x] [ 1 ] + = s[ 2 ] ) % = mo;
Inv= ksm ( mo+ 1 - s[ 1 ] , mo- 2 ) ;
f[ x] [ 0 ] = f[ x] [ 0 ] * Inv% mo;
f[ x] [ 1 ] = f[ x] [ 1 ] * Inv% mo;
f[ x] [ 2 ] = f[ x] [ 2 ] * Inv% mo;
g[ x] = g[ x] * Inv% mo;
}
}
ll ans[ maxn] ;
void dfs2 ( int x, int p) {
ans[ x] = f[ x] [ 0 ] ;
if ( fa[ x] ) ( ans[ x] + = ans[ fa[ x] ] * f[ x] [ 1 ] ) % = mo;
if ( fa[ fa[ x] ] ) ( ans[ x] + = ans[ fa[ fa[ x] ] ] * f[ x] [ 2 ] ) % = mo;
for ( int i= ls[ x] ; i; i= nx[ i] ) if ( e[ i] != p)
dfs2 ( e[ i] , x) ;
}
int main ( ) {
freopen ( "ceshi.in" , "r" , stdin ) ;
freopen ( "ceshi.out" , "w" , stdout ) ;
scanf ( "%d%d" , & n, & m) ;
inv[ 0 ] = 1 ; for ( i= 1 ; i<= n; i++ ) inv[ i] = ksm ( i, mo- 2 ) ;
for ( i= 1 ; i< n; i++ ) scanf ( "%d%d" , & j, & k) , insert ( j, k) ;
for ( i= 1 ; i<= m; i++ ) scanf ( "%d" , & k) , bz[ k] = 1 ;
for ( int x= 1 ; x<= n; x++ ) {
cnt[ x] = du[ x] + 1 ;
for ( i= ls[ x] ; i; i= nx[ i] ) cnt[ x] + = du[ e[ i] ] - 1 ;
}
dfs ( 1 , 0 ) ;
dfs2 ( 1 , 0 ) ;
for ( i= 1 ; i<= n; i++ ) printf ( "%lld\n" , ans[ i] ) ;
}