题目链接:Just A Mistake
考虑先枚举根,求根必选的排列数,这样可以算出根对答案的贡献。然后,设dp[i][j]表示以i为根的子树,i排在第j位且i必选的方案数。考虑子树v,如果v的位置在i的前面,那么v必须不选,那么我们把他减掉;如果v的位置在i后面,那样在我们选了i以后自然不会选v。然后就是对两个排列求并的方案数,这个和【HDU】5789 Permutation一个道理,都是枚举一个排列有多少个在另一个排列的特定位置前,算个组合数就好。
PS:一开始没有想容斥,想用dp[i][j][0~3]表示,dp[i][j][0]表示根必选,i在位置j的方案数;dp[i][j][1]表示根不选,距离根最近的点到根的距离1且在排列上最远的位置是j的方案数;dp[i][0][2]表示距离根最近的点到根的距离为2的方案数(此时j无用)。然后写了好久发现脑子已经成浆糊了,果断抛弃……
#include <bits/stdc++.h>
using namespace std ;
typedef long long LL ;
typedef pair < int , int > pii ;
#define clr( a , x ) memset ( a , x , sizeof a )
const int MAXN = 205 ;
const int mod = 1e9 + 7 ;
vector < int > G[MAXN] ;
int f[MAXN] , vf[MAXN] , c[MAXN][MAXN] ;
int dp[MAXN][MAXN] ;
int nxt[MAXN] ;
int siz[MAXN] ;
int n ;
void up ( int& x , int y ) {
x += y ;
if ( x >= mod ) x -= mod ;
}
void dfs ( int u , int fa ) {
siz[u] = 1 ;
dp[u][1] = 1 ;
for ( int i = 0 ; i < G[u].size () ; ++ i ) {
int v = G[u][i] ;
if ( v == fa ) continue ;
dfs ( v , u ) ;
int sum = f[siz[v]] , num = siz[v] , tot = siz[v] + siz[u] ;
for ( int j = 0 ; j <= tot ; ++ j ) {
nxt[j] = 0 ;
}
for ( int j = 0 , num = siz[v] ; j <= siz[v] ; ++ j , -- num ) {
up ( sum , mod - dp[v][j] ) ;
for ( int k = 1 ; k <= siz[u] ; ++ k ) {
up ( nxt[j + k] , 1LL * sum * dp[u][k] % mod * c[j + k - 1][j] % mod * c[tot - j - k][num] % mod ) ;
}
}
siz[u] = tot ;
for ( int j = 0 ; j <= tot ; ++ j ) {
dp[u][j] = nxt[j] ;
}
}
}
void solve () {
scanf ( "%d" , &n ) ;
for ( int i = 1 ; i <= n ; ++ i ) {
G[i].clear () ;
}
for ( int i = 1 ; i < n ; ++ i ) {
int u , v ;
scanf ( "%d%d" , &u , &v ) ;
G[u].push_back ( v ) ;
G[v].push_back ( u ) ;
}
int ans = 0 ;
for ( int i = 1 ; i <= n ; ++ i ) {
clr ( dp , 0 ) ;
dfs ( i , i ) ;
for ( int j = 1 ; j <= n ; ++ j ) {
up ( ans , dp[i][j] ) ;
}
}
printf ( "%d\n" , ans ) ;
}
int main () {
int T ;
f[0] = 1 ;
clr ( c , 0 ) ;
c[0][0] = 1 ;
for ( int i = 1 ; i < MAXN ; ++ i ) {
f[i] = 1LL * f[i - 1] * i % mod ;
c[i][0] = c[i][i] = 1 ;
for ( int j = 1 ; j < i ; ++ j ) {
c[i][j] = ( c[i - 1][j - 1] + c[i - 1][j] ) % mod ;
}
}
scanf ( "%d" , &T ) ;
for ( int i = 1 ; i <= T ; ++ i ) {
printf ( "Case #%d: " , i ) ;
solve () ;
}
return 0 ;
}