题目
Description
小A和小B进行一个游戏,在一棵n个点,以1为根的树中,两人各占其中一半的点,每个回合他们会拿出一个之前没用过的点,如果对手的点在自己的点的子树内,则该回合获胜。如果自己的点在对方的子树内,该回合失败。其他情况视为平局。当所有点都被用完时游戏结束。
作为旁观者的你只想知道当他们随机选点的情况下期望有几个回合能决出胜负,即非平局的回合数的期望。
为了计算这个期望,你决定对于k=0~n,计算出恰好有k个回合能决出胜负的方案数,两种方案不同当且仅当存在一个属于小A的点,他被使用的那一个回合小B使用的点不同。答案对998244353取模。
Input
第一行一个正整数n,保证n是偶数。
第二行一个长度为n的01串,第i个字符是0表示点i属于小A,否则点i属于小B。保证0和1的个数相同。
接下来n-1行,每行两个正整数表示一条边的两个端点。
Output
一共一行n/2+1个整数,第i个数表示恰好有i-1个回合能决出胜负的方案数,答案对998244353取模。
Sample Input
8
10010011
1 2
1 3
2 4
2 5
5 6
3 7
3 8
Sample Output
0 10 10 4 0
Data Constraint
思路
考试的时候理解错题意了
首先可以写一个n6的暴力,记f[i][j][k]表示在i的子树内有j对互相匹配,其中有k对是祖孙关系的方案数。
合并两个子树的信息:枚举两个子树内匹配了的对数以及祖孙关系的对数,再枚举这两个子树之间互相匹配的黑点和白点的数量。乘上相应的组合数。
前面的动态规划的优化瓶颈在于不为祖孙关系的点对的方案数统计很复杂,如果能只统计互为祖孙关系的点对将会非常好处理。
记状态f[i][j]表示i的子树中选取不重复的j对互为祖孙关系的点对的方案数。
这个dp的复杂度分析可以参考树形依赖背包,是O(n^2)的
我们只关心这j对一定是祖孙关系,而其他的点可以任意匹配。这里只需要在最后乘上一个阶乘就好了。
记g[i]=f[1][i]*((n/2-i)!),ans[i]为恰好i对的方案数。
在这种情况下,在这j对点之外还可能有其他的点对为祖孙关系,我们需要去掉这些情况。
事实上,一个存在x对祖孙关系的方案会被g[i]统计C(x,i)次,即g[i]=∑C(x,i)*ans[x]
移项得ans[i]=g[i]-∑C(x,i)*ans[x] (i≠x)
我们知道当x<i时C(x,i)=0,所以只需要从大到小枚举i就可以求出全部的答案。
时间复杂度O(n^2)
代码
#include<bits/stdc++.h>
#define N 5077
#define ll long long
#define mod 998244353
using namespace std;
int n,bz[N];
int cnt,e[N*2],nx[N*2],ls[N];
ll sz[N][2],f[N][N/2],fct[N],ans,C[N/2][N/2];
char ch;
void ins(int x,int y){
cnt++; e[cnt]=y; nx[cnt]=ls[x]; ls[x]=cnt;
cnt++; e[cnt]=x; nx[cnt]=ls[y]; ls[y]=cnt;
}
void dfs(int x,int p){
for(int i=ls[x]; i; i=nx[i]) if(e[i]!=p)
dfs(e[i],x);
f[x][0]=1;
for(int i=ls[x]; i; i=nx[i]) if(e[i]!=p){
int y=e[i];
for(int j=min(sz[x][0],sz[x][1]); j>=0; j--)
for(int k=1; k<=min(sz[y][0],sz[y][1]); k++)
(f[x][j+k]+=f[x][j]*f[y][k])%=mod;
sz[x][0]+=sz[y][0],sz[x][1]+=sz[y][1];
}
for(int j=min(sz[x][0],sz[x][1]); j>=0; j--)
(f[x][j+1]+=f[x][j]*(sz[x][bz[x]^1]-j))%=mod;
sz[x][bz[x]]++;
}
int main()
{
freopen("match.in","r",stdin); freopen("match.out","w",stdout);
scanf("%d",&n);
for(ch=getchar();ch<'0'||ch>'1';ch=getchar());
for(int i=1; i<=n; i++) bz[i]=ch-'0',ch=getchar();
fct[0]=1;for(int i=1; i<=n; i++) fct[i]=fct[i-1]*i%mod;
for(int i=1,x,y; i<n; i++) scanf("%d%d",&x,&y),ins(x,y);
C[0][0]=1;
for(int i=1; i<=n/2; i++)
{
C[i][0]=1;
for(int j=1; j<=i; j++)
C[i][j]=(C[i-1][j]+C[i-1][j-1])%mod;
}
dfs(1,0);
for(int i=0; i<=n/2; i++) f[1][i]=f[1][i]*fct[n/2-i]%mod;
for(int i=0; i<=n/2; i++)
{
ans=0;
for(int j=i; j<=n/2; j++) ans+=f[1][j]*C[j][i]*(((j-i)&1)?-1:1)%mod;
printf("%lld ",(ans%mod+mod)%mod);
}
}