总结啥的就放到多项式入门里了,好多细节需要注意~
code:
#include <bits/stdc++.h>
#define ll long long
#define setIO(s) freopen(s".in","r",stdin)
using namespace std;
const int mod=998244353,G=3,N=1000003;
int A[N],B[N],f[N],g[N],inv2,C[N],D[N];
inline int qpow(int x,int y)
{
int tmp=1;
for(;y;y>>=1,x=(ll)x*x%mod) if(y&1) tmp=1ll*tmp*x%mod;
return tmp;
}
inline int INV(int x) { return qpow(x,mod-2); }
void NTT(int *a,int len,int flag)
{
int i,j,k,mid;
for(i=k=0;i<len;++i)
{
if(i>k) swap(a[i],a[k]);
for(j=len>>1;(k^=j)<j;j>>=1);
}
for(mid=1;mid<len;mid<<=1)
{
int wn=qpow(G,(mod-1)/(mid<<1));
if(flag==-1) wn=INV(wn);
for(i=0;i<len;i+=mid<<1)
{
int w=1;
for(j=0;j<mid;++j)
{
int x=a[i+j], y=1ll*w*a[i+j+mid]%mod;
a[i+j]=1ll*(x+y)%mod, a[i+j+mid]=1ll*(x-y+mod)%mod;
w=1ll*w*wn%mod;
}
}
}
if(flag==-1)
{
int rev=INV(len);
for(i=0;i<len;++i) a[i]=1ll*a[i]*rev%mod;
}
}
void getinv(int *a,int *b,int len)
{
if(len==1) { b[0]=INV(a[0]); return; }
getinv(a,b,len>>1);
int i,j;
for(i=0;i<(len<<1);++i) C[i]=D[i]=0;
for(i=0;i<len;++i) C[i]=a[i], D[i]=b[i];
NTT(C,len<<1,1);
NTT(D,len<<1,1);
for(i=0;i<(len<<1);++i) C[i]=1ll*C[i]*D[i]%mod*D[i]%mod;
NTT(C,len<<1,-1);
for(i=0;i<len;++i) b[i]=((b[i]<<1)%mod-C[i]+mod)%mod;
}
void getsqrt(int *a,int *b,int len)
{
if(len==1) { b[0]=1; return; }
getsqrt(a,b,len>>1);
int i,j;
for(i=0;i<(len<<1);++i) A[i]=B[i]=0;
getinv(b,B,len);
for(i=0;i<len;++i) A[i]=a[i];
NTT(A,len<<1,1);
NTT(B,len<<1,1);
for(i=0;i<(len<<1);++i) A[i]=1ll*A[i]*B[i]%mod;
NTT(A,len<<1,-1);
for(i=0;i<len;++i) b[i]=1ll*(b[i]+A[i])%mod*inv2%mod;
}
int main()
{
// setIO("input");
int n,i,j,lim=1;
inv2=INV(2);
scanf("%d",&n);
for(i=0;i<n;++i) scanf("%d",&f[i]);
while(lim<n) lim<<=1;
getsqrt(f,g,lim);
for(i=0;i<n;++i) printf("%d ",g[i]);
return 0;
}