很棘手的树形 d p dp dp 。
设 d p [ u ] [ x ] dp[u][x] dp[u][x] 表示以 u u u 为根的子树,有 x x x 个节点失配的方案数。这样时间复杂度 O ( n 3 ) O(n^3) O(n3) 。
一个想法是直接继承重儿子,用 d s u − o n − t r e e dsu-on-tree dsu−on−tree 优化转移。不过比较难写,而且 n < = 5000 n<=5000 n<=5000 似乎过不去。
但是看到 任意边都必须有颜色 时想到容斥。具体来说,就是 ∑ ( − 1 ) m ∗ 有 m 条 边 没 有 颜 色 的 方 案 数 \sum(-1)^m*有m条边没有颜色的方案数 ∑(−1)m∗有m条边没有颜色的方案数 。后文的方案数均指带上容斥系数的方案数。
考虑一条边只有选和不选两种决策,每次加入一条边,只考虑当前边有没有颜色覆盖,具体来说就是和之前余下的点是否形成一个连通块。
下面叙述正解:
首先考虑 2 n 2n 2n 个节点的自由匹配数。答案为 A 2 n n 2 n = ( 2 n − 1 ) ! ! \frac{A_{2n}^n}{2^n}=(2n-1)!! 2nA2nn=(2n−1)!! 。比较容易理解的方法是考虑节点 1 1 1 有 2 n − 1 2n-1 2n−1 种连法,节点 2 2 2 有 2 n − 3 2n-3 2n−3 种连法,直到 n = 1 n=1 n=1 时有一种连法。
然后考虑 F F F 条边断掉,此时方案数为 g ( n 1 ) ∗ g ( n 2 ) . . . g ( n ∣ F + 1 ∣ ) g(n_1)*g(n_2)...g(n_{|F+1|}) g(n1)∗g(n2)...g(n∣F+1∣) 。其中 g ( n 1 ) g(n1) g(n1) 表示连通块 n 1 n1 n1 的方案数。
其实你会发现枚举断边是套路。如果 n < = 20 n<=20 n<=20 可以直接状压枚举,但是因为本题形态是一个树,可以树形 dp 优化。
设 d p [ u ] [ x ] dp[u][x] dp[u][x] 表示以 u u u 为根的子树,有 x x x 个节点组成的连通块的方案数。
具体地,每多一条断边,就乘上系数 − 1 -1 −1 。有状态转移方程:
- d p [ u ] [ x ] ∗ d p [ v ] [ y ] − > d p [ u ] [ x + y ] dp[u][x]*dp[v][y]->dp[u][x+y] dp[u][x]∗dp[v][y]−>dp[u][x+y]
- − d p [ u ] [ x ] ∗ d p [ v ] [ y ] ∗ g [ y ] − > d p [ u ] [ x ] -dp[u][x]*dp[v][y]*g[y]->dp[u][x] −dp[u][x]∗dp[v][y]∗g[y]−>dp[u][x]
最后枚举 x = [ 0 , n ) x=[0,n) x=[0,n) 即可。时间复杂度 O ( n 2 ) O(n^2) O(n2) 。
本题告诉我们,容斥思想是很灵活的,可以融入到 d p dp dp 转移中降低时间复杂度。
#include<bits/stdc++.h>
#define fi first
#define se second
#define ll long long
#define PII pair<int,int>
#define All(x) x.begin(),x.end()
#define INF 0x3f3f3f3f
using namespace std;
const int mx=5005;
const int mod=1e9+7;
int n,siz[mx];
ll dp[mx][mx],f[mx],tmp[mx],res;
vector<int> g[mx];
void dfs(int u,int fa) {
dp[u][1]=siz[u]=1;
for(auto v:g[u]) {
if(v==fa) continue;
dfs(v,u);
for(int i=0;i<=siz[u]+siz[v];i++) tmp[i]=0;
for(int i=siz[u];i>=0;i--) {
for(int j=siz[v];j>=0;j--) {
tmp[i]=(tmp[i]-dp[u][i]*dp[v][j]%mod*f[j]%mod)%mod;
tmp[i+j]=(tmp[i+j]+dp[u][i]*dp[v][j]%mod)%mod;
}
}
for(int i=0;i<=siz[u]+siz[v];i++) dp[u][i]=tmp[i];
siz[u]+=siz[v];
}
}
int main() {
// freopen("data.in","r",stdin);
scanf("%d",&n);
for(int i=1;i<n;i++) {
int u,v; scanf("%d%d",&u,&v);
g[u].push_back(v),g[v].push_back(u);
}
f[0]=1;
for(int i=2;i<=n;i++) {
f[i]=f[i-2]*(i-1)%mod;
}
dfs(1,0);
for(int i=0;i<=n;i++) {
res=(res+dp[1][i]*f[i]%mod)%mod;
}
if(res<0) res+=mod;
printf("%lld",res);
}