这题是我搜NTT搜到的,当时就看到“多项式开根”这样的标题,于是找到了L-leader的博客,补了下幂级数的东西,用两节数学课学会了。
我再看题解,好像都是教我怎么开方,求逆的,然后又拖了几天。终于昨晚睡不着,突然就想到了。。。
先介绍一下生成函数。
简单的说,就是一个数组a[0..n],可以生成一个多项式函数(幂级数)
f(x)=∑i=0na[i]∗xi
题意是给你二叉树每个节点可能的点权集合C,元素都<=1e5,对于所有1<=s<=m,有种不同的二叉树满足点权和为s,答案模一个费马素数。
设g[i]为一个01数组,表示i是否在C出现,f[i]为权值和为i的方案数,即是答案。F为f的生成函数,G为g的生成函数,根据题意g[0]=0,因为存在空树,f[0]=1;
我们可以枚举二叉树根的权值,剩下左右儿子为子问题,就有
f[x]=∑i=0xg[i]∑j=0x−if[j]∗f[x−i−j]
就可以大概知道 F=F2G 。
根据f[0]=1,g[1]=0,所以有 F=F2G+1
通过解一元二次方程,再结合f[0]=1,g[1]=0
F=21+1−4G−−−−−−√
然后就是多项式求逆和多项式开根了。
#include <iostream>
#include <fstream>
#include <algorithm>
#include <cmath>
#include <ctime>
#include <cstdio>
#include <cstdlib>
#include <cstring>
using namespace std;
#define mmst(a, b) memset(a, b, sizeof(a))
#define mmcp(a, b) memcpy(a, b, sizeof(b))
typedef long long LL;
const int p=998244353,I2=499122177;
const int N=800400;
int cheng(int a,int b)
{
int res=1;
for(;b;b>>=1,a=(LL)a*a%p)
if(b&1)
res=(LL)res*a%p;
return res;
}
int n,rev[N];
void init(int lim)
{
n=1;
int k=-1;
while(n<lim)
n<<=1,k++;
for(int i=0;i<n;i++)
rev[i]=(rev[i>>1] >> 1) | ((i&1)<<k);
}
void ntt(int *a,int ops)
{
for(int i=0;i<n;i++)
if(i<rev[i])
swap(a[i],a[rev[i]]);
for(int l=2;l<=n;l<<=1)
{
int m=l>>1,wn;
if(ops)
wn=cheng(3,(p-1)/l);
else
wn=cheng(3,p-1-(p-1)/l);
for(int i=0;i<n;i+=l)
{
int w=1;
for(int k=0;k<m;k++)
{
int t=(LL)a[i+k+m]*w%p;
a[i+k+m]=(a[i+k]-t+p)%p;
a[i+k]=(a[i+k]+t)%p;
w=(LL)w*wn%p;
}
}
}
if(!ops)
{
int Inv=cheng(n,p-2);
for(int i=0;i<n;i++)
a[i]=(LL)a[i]*Inv%p;
}
}
int g[N];
int mx=1,by,nn,mm;
int X[N],Y[N],sqr[N],A[N],B[N],C[N];
void Inverse(int *a,int *b,LL len)
{
if(len==1)
{
b[0]=cheng(a[0],p-2);
return;
}
Inverse(a,b,len>>1);
init(2*len);
for(int i=0;i<len;i++)
X[i]=a[i];
for(int i=0;i<(len>>1);i++)
Y[i]=b[i];
ntt(X,1);
ntt(Y,1);
for(int i=0;i<n;i++)
X[i]=(2ll*Y[i]%p-(LL)X[i]*Y[i]%p*Y[i]%p+p)%p;
ntt(X,0);
for(int i=0;i<n;i++)
{
if(i>=len)
b[i]=0;
else
b[i]=X[i];
X[i]=Y[i]=0;
}
}
void Sqrt(int len)
{
if(len==1)
{
sqr[0]=1;//本题被开方的多项式常数项为1
return;
}
Sqrt(len>>1);
Inverse(sqr,A,len);
for(int i=0;i<(len>>1);i++)
B[i]=sqr[i];
for(int i=0;i<len;i++)
C[i]=g[i];
init(len*2);
ntt(A,1);
ntt(B,1);
ntt(C,1);
for(int i=0;i<n;i++)
A[i]=(1ll*C[i]+(LL)B[i]*B[i])%p*I2%p*A[i]%p;
ntt(A,0);
for(int i=0;i<n;i++)
{
sqr[i]=A[i];
if(i>=len)
sqr[i]=0;
A[i]=B[i]=C[i]=0;
}
}
int main()
{
cin>>nn>>mm;
while(mx<=mm)
mx<<=1;
for(int i=1;i<=nn;i++)
{
scanf("%d",&by);
if(by<=mm)
g[by]=1;
}
for(int i=0;i<mx;i++)
if(g[i])
g[i]=p-4;
g[0]=1;
Sqrt(mx);
sqr[0]=(sqr[0]+1)%p;
mmst(g,0);
Inverse(sqr,g,mx);
for(int i=1;i<=mm;i++)
printf("%d\n",(g[i]+g[i])%p);
return 0;
}