题意
对于 1 到 m m m 的整数 i i i,统计有多少个带权二叉树,点权权值属于给定集合 C C C,点权和为 i i i。答案模 998244353 998244353 998244353。 ∣ C ∣ , m ≤ 1 0 5 |C|,m\leq 10^5 ∣C∣,m≤105。
题解
单个节点的生成函数 F ( x ) = ∑ i = 1 [ i ∈ C ] x i F(x)=\sum\limits_{i=1}[i\in C]x^i F(x)=i=1∑[i∈C]xi。
一个二叉树是由左子树、右子树、根拼起来的,即: G = F ⋅ G 2 G=F\cdot G^2 G=F⋅G2。
解方程可得 G = 1 ± 1 − 4 F 2 F G=\dfrac{1\pm\sqrt{1-4F}}{2F} G=2F1±1−4F。
F F F 的常数项为 0 0 0, 1 − 4 F \sqrt{1-4F} 1−4F 的常数项为 1, G G G 的常数项为 1,所以取 G = 1 − 1 − 4 F 2 F = 2 1 + 1 − 4 F G=\dfrac{1-\sqrt{1-4F}}{2F}=\dfrac{2}{1+\sqrt{1-4F}} G=2F1−1−4F=1+1−4F2。
多项式求逆+开方即可。
代码:
/**********
Author: WLBKR5
Problem: codeforces 438E, bzoj 3625
Name: The Child and Binary Tree, 小朋友和二叉树
Source: Codeforces Round #250 (438)
Algorithm: 生成函数
Date: 2020/06/04
Statue: accepted
Submission: codeforces.com/contest/438/submission/82466147, darkbzoj.tk/submission/68524,
**********/
#include<bits/stdc++.h>
using namespace std;
int getint(){
int ans=0,f=1;
char c=getchar();
while(c<'0'||c>'9'){
if(c=='-')f=-1;
c=getchar();
}
while(c>='0'&&c<='9'){
ans=ans*10+c-'0';
c=getchar();
}
return ans;
}
const int N=4e5+10,mod=998244353,G=3,inv2=(mod+1)>>1;
int f[N];
int qpow(int x,int y){
int ans=1;
while(y){
if(y&1)ans=ans*1ll*x%mod;
x=x*1ll*x%mod;
y>>=1;
}
return ans;
}
int w[N],iw[N],maxn;
int rev[N];
void init_w(int n){
maxn=n;
w[0]=iw[0]=1;
w[1]=qpow(G,(mod-1)/n);
iw[1]=qpow(w[1],mod-2);
for(int i=2;i<n;i++)
w[i]=w[i-1]*1ll*w[1]%mod,
iw[i]=iw[i-1]*1ll*iw[1]%mod;
}
void init_rev(int n,int l){
for(int i=1;i<n;i++)
rev[i]=(rev[i>>1]>>1)|((i&1)<<(l-1));
}
void ntt(int *a,int n,int r=1){
for(int i=0;i<n;i++)if(rev[i]>i)swap(a[rev[i]],a[i]);
int *W=(~r?w:iw);
for(int i=1;i<n;i<<=1){
int d=maxn/i/2;
for(int j=0;j<n;j+=i*2){
int t=0;
for(int k=0;k<i;k++,t+=d){
int x=a[j+k],y=a[j+k+i]*1ll*W[t]%mod;
a[j+k]=(x+y)%mod;
a[j+k+i]=(x-y+mod)%mod;
}
}
}
if(!~r){
int invn=qpow(n,mod-2);
for(int i=0;i<n;i++)a[i]=a[i]*1ll*invn%mod;
}
}
int f_[N],g[N],h[N];
void get_inv(int *f,int n){
if(n==1){
g[0]=qpow(f[0],mod-2);
return;
}
get_inv(f,(n+1)>>1);
memcpy(h,g,sizeof(int)*n);
int nn=1,l=0;
while(nn<=n*2)nn<<=1,++l;
init_rev(nn,l);
for(int i=0;i<n;i++)f_[i]=f[i];
for(int i=0;i<n;i++)g[i]=h[i]*2%mod;
for(int i=n;i<nn;i++)f_[i]=g[i]=h[i]=0;
ntt(f_,nn);
ntt(h,nn);
for(int i=0;i<nn;i++)h[i]=h[i]*1ll*h[i]%mod*f_[i]%mod;
ntt(h,nn,-1);
for(int i=0;i<n;i++)g[i]=(g[i]-h[i]+mod)%mod;
//for(int i=0;i<n;i++)cerr<<"."<<g[i];cerr<<endl;
}
int sf_[N],sg[N],sh[N];
void get_sqrt(int *sf,int n){
if(n==1){
sg[0]=1;
return;
}
get_sqrt(sf,(n+1)>>1);
memcpy(sh,sg,sizeof(int)*n);
for(int i=0;i<n;i++){
sg[i]=sh[i]*1ll*inv2%mod;
sh[i]=sh[i]*2%mod;
}
get_inv(sh,n);
int nn=1,l=0;
while(nn<=n*2)nn<<=1,++l;
init_rev(nn,l);
for(int i=0;i<n;i++)sf_[i]=sf[i];
for(int i=n;i<nn;i++)sf_[i]=0;
ntt(g,nn);
ntt(sf_,nn);
for(int i=0;i<nn;i++)sh[i]=g[i]*1ll*sf_[i]%mod;
ntt(sh,nn,-1);
for(int i=0;i<n;i++)sg[i]=(sg[i]+sh[i])%mod;
//for(int i=0;i<n;i++)cerr<<" "<<sg[i];cerr<<endl;
}
int main(){
int n=getint(),m=getint()+1;
int mm=1;while(mm<=m*2)mm<<=1;init_w(mm);
for(int i=0;i<n;i++)f[getint()]=mod-4;
f[0]=1;
//for(int i=0;i<m;i++)cerr<<"> "<<f[i];cerr<<endl;
get_sqrt(f,m);
//for(int i=0;i<m;i++)cerr<<"> "<<sg[i];cerr<<endl;
sg[0]++;
get_inv(sg,m);
//for(int i=0;i<m;i++)cerr<<"> "<<g[i];cerr<<endl;
for(int i=0;i<m;i++){
g[i]=g[i]*2%mod;
}
for(int i=1;i<m;i++){
printf("%d\n",g[i]);
}
return 0;
}