题目大意
给定一棵大小为
n
的树,和一个限制
现在要给每个结点赋值为一个整数,范围为
[1,m]
,且要求树上相邻两点间权值之差大于等于给定的
k
。
求所有合法方案数。
T组数据。
Data Constraint
题解
设状态
f[i][j]
表示第
i
个结点,取值为
那么
f[i][j]=Π(∑|k−j|≥kf[son][k])
然后可以发现对于每个 f[i] ,它前面 (n−1)∗k 个元素与最后 (n−1)∗k 个元素对称,中间是连续一段相同的数。所以只要求出前 (n−1)∗k 个元素的之就可以了。
SRC
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std ;
#define N 100 + 10
#define M 100000 + 10
typedef long long ll ;
const int MO = 1e9 + 7 ;
int Node[2*N] , Next[2*N] , Head[N] , tot ;
ll f[N][M] , g[N] , S[N][M] ;
int T , n , m , K ;
int Size = 10000 ;
void link( int u , int v ) {
Node[++tot] = v ;
Next[tot] = Head[u] ;
Head[u] = tot ;
}
inline ll GetSum1( int x , int l ) {
if ( l < 1 || l > m ) return 0 ;
int B = Size ;
if ( l <= min( m , B ) ) return S[x][l] ;
else {
if ( l <= m - B ) return (S[x][B] + (ll)(l - B) * g[x] % MO) % MO ;
else return ((S[x][B] + (ll)(m - 2 * B) * g[x] % MO) % MO + (S[x][B] - S[x][m-l] + MO) % MO) % MO ;
}
}
inline ll GetSum2( int x , int r ) {
if ( r < 1 || r > m ) return 0 ;
int B = Size ;
if ( m <= B ) return (S[x][m] - S[x][r-1] + MO) % MO ;
if ( r > m - B ) return S[x][m-r+1] ;
if ( r > B ) return (S[x][B] + (ll)(m - B - r + 1) * g[x] % MO) % MO ;
else return ((S[x][B] + (ll)(m - 2 * B) * g[x] % MO) + (S[x][B] - S[x][r-1] + MO) % MO) % MO ;
}
void DFS( int x , int Fa ) {
g[x] = 1 ;
for (int y = 1 ; y <= Size ; y ++ ) f[x][y] = 1 ;
for (int p = Head[x] ; p ; p = Next[p] ) {
if ( Node[p] == Fa ) continue ;
if ( Node[p] == 36 ) {
Node[p] ++ ;
Node[p] -- ;
}
DFS( Node[p] , x ) ;
for (int y = 1 ; y <= min( m , Size + 1 ) ; y ++ ) {
ll ret = (GetSum1( Node[p] , y - K ) + GetSum2( Node[p] , y + K )) % MO ;
if ( ret < 0 ) {
ret ++ ;
ret -- ;
}
if ( y == Size + 1 ) g[x] = ((ll)g[x] * ret) % MO ;
else f[x][y] = ((ll)f[x][y] * ret) % MO ;
}
}
for (int y = 1 ; y <= Size ; y ++ ) S[x][y] = (S[x][y-1] + f[x][y]) % MO ;
}
int Power( int x , int k ) {
int s = 1 ;
while ( k ) {
if ( k & 1 ) s = (ll)s * x % MO ;
x = (ll)x * x % MO ;
k /= 2 ;
}
return s ;
}
int main() {
freopen( "label.in" , "r" , stdin ) ;
freopen( "label.out" , "w" , stdout ) ;
scanf( "%d" , &T ) ;
while ( T -- ) {
tot = 0 ;
Size = 10000 ;
memset( f , 0 , sizeof(f) ) ;
memset( g , 0 , sizeof(g) ) ;
memset( S , 0 , sizeof(S) ) ;
memset( Head , 0 , sizeof(Head) ) ;
scanf( "%d%d%d" , &n , &m , &K ) ;
for (int i = 1 ; i < n ; i ++ ) {
int x , y ;
scanf( "%d%d" , &x , &y ) ;
link( x , y ) ;
link( y , x ) ;
}
if ( !K ) { printf( "%d\n" , Power( m , n ) ) ; continue ; }
if ( (n - 1) * K > Size ) Size = (n - 1) * K ;
DFS( 1 , 0 ) ;
printf( "%lld\n" , GetSum1( 1 , m ) ) ;
}
return 0 ;
}
以上.