题目链接:https://www.luogu.com.cn/problem/P6478
题意:给定一棵树,每次从里面选出各一个权值为
1
1
1和
0
0
0的结点,直到选完为止,求每次选出结点是另一个该次选出结点父亲的方案数
直接计算答案很难,但我们发现选出至少 i i i个结点满足题目要求比较容易算出来,再利用二项式反演容斥一下即可
设
d
p
[
i
]
[
j
]
dp[i][j]
dp[i][j]表示从
i
i
i及其子树中选出
j
j
j对题目所要求的点的方案数
d
p
[
i
]
[
j
]
=
∑
k
1
+
k
2
+
.
.
.
+
k
m
=
j
d
p
[
v
1
]
[
k
1
]
∗
d
p
[
v
2
]
[
k
2
]
∗
.
.
.
∗
d
p
[
v
m
]
[
k
m
]
dp[i][j]=\sum_{k_1+k_2+...+k_m=j} dp[v_1][k_1]*dp[v_2][k_2]*...*dp[v_m][k_m]
dp[i][j]=k1+k2+...+km=j∑dp[v1][k1]∗dp[v2][k2]∗...∗dp[vm][km]
其中
v
1
,
v
2
,
v
3
.
.
.
v
n
v_1,v_2,v_3...v_n
v1,v2,v3...vn是
i
i
i的子结点集合
考虑每个儿子结点对其父亲贡献,转移时让
d
p
[
i
]
[
j
]
dp[i][j]
dp[i][j]表示已经转移过的儿子对其的贡献(方案数)
那么在这次转移时只需乘上本次的贡献即可,即可表示所有子树的贡献
其实就是分步计算原先的
d
p
dp
dp方程式
for (register int j=0;j<=size[x]/2;++j){
if (!dp[x][j]) continue;
for (register int k=1;k<=size[v]/2;++k)
q[j+k]=(q[j+k]+dp[x][j]*dp[v][k])%Mod;
}
for (register int j=1;j<=size[x]/2+size[v]/2;++j){
dp[x][j]=(dp[x][j]+q[j])%Mod;
q[j]=0;
}
这玩意看上去总的复杂度是
O
(
n
3
)
O(n^3)
O(n3)
但其实是
O
(
n
2
)
O(n^2)
O(n2) 我目前还不会证
然后设
f
[
i
]
f[i]
f[i]为选出至少
i
i
i个结点满足题目要求的方案数
f
[
i
]
=
d
p
[
1
]
[
i
]
∗
(
n
−
i
)
!
f[i]=dp[1][i]*(n-i)!
f[i]=dp[1][i]∗(n−i)!
利用二项式反演的公式即可求解
C o d e Code Code
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int MAXN=5000,Mod=998244353;
struct w{
int to,nx;
}head[MAXN+MAXN+10];
int a[MAXN+10],val[MAXN+10],cnt[MAXN+10],size[MAXN+10],q[MAXN+10],dp[MAXN+10][MAXN+10];
int f[MAXN+10],g[MAXN+10],p[MAXN+10],c[MAXN+10][MAXN+10],n,m;
inline int read();
inline void add(int,int,int);
void dfs(int,int);
signed main(){
//freopen ("std.in","r",stdin);
//freopen ("std.out","w",stdout);
n=read(); m=n/2;
p[0]=c[0][0]=1;
for (register int i=1;i<=n;++i){
char c=getchar();
while (!isdigit(c))c=getchar();
val[i]=c-'0',p[i]=p[i-1]*i%Mod;
}
for (register int i=1;i<=m;++i){
c[i][0]=1;
for (register int j=1;j<=i;++j)
c[i][j]=(c[i-1][j-1]+c[i-1][j])%Mod;
}
for (register int i=1;i<n;++i){
int x=read(),y=read();
add(x,y,i+i-1);
add(y,x,i+i);
}
dfs(1,0);
for (register int i=0;i<=m;++i) f[i]=dp[1][i]*p[m-i]%Mod;
for (register int i=0;i<=m;++i)
for (register int k=i;k<=m;++k){
int x=k-i&1 ? -1 : 1;
g[i]+=x*f[k]*c[k][i];
g[i]=(g[i]%Mod+Mod)%Mod;
}
for (register int i=0;i<=m;++i)
printf("%lld\n",g[i]);
return 0;
}
inline int read(){
int x=0;
char c=getchar();
while (!isdigit(c))c=getchar();
while (isdigit(c))x=(x<<1)+(x<<3)+(c&15),c=getchar();
return x;
}
inline void add(int x,int y,int i){head[i].to=y;head[i].nx=a[x];a[x]=i;}
void dfs(int x,int fa){
size[x]=dp[x][0]=1;
for (register int i=a[x];i;i=head[i].nx){
int v=head[i].to;
if (v==fa) continue;
dfs(v,x);
for (register int j=0;j<=size[x]/2;++j){
if (!dp[x][j]) continue;
for (register int k=1;k<=size[v]/2;++k)
q[j+k]=(q[j+k]+dp[x][j]*dp[v][k])%Mod;
}
for (register int j=1;j<=size[x]/2+size[v]/2;++j){
dp[x][j]=(dp[x][j]+q[j])%Mod;
q[j]=0;
}
q[0]=0;size[x]+=size[v],cnt[x]+=cnt[v];
}
for (register int i=size[x];i>=1;--i){
if (val[x]) dp[x][i]+=dp[x][i-1]*max(cnt[x]-i+1,(long long)0);
else dp[x][i]+=dp[x][i-1]*max(size[x]-cnt[x]-i,(long long)0);
dp[x][i]%=Mod;
}
if (!val[x]) ++cnt[x];
}