1.多项式求逆
【代码】
#include<bits/stdc++.h>
using namespace std;
#define mp make_pair
#define fi first
#define se second
#define lson now<<1
#define rson now<<1|1
typedef long long ll;
const int mod=998244353;
const int maxn=8e5+5;
int a[maxn],b[maxn],c[maxn],n;
int qpow(int a,int b)
{
int res=1;
while(b)
{
if(b&1) res=1LL*res*a%mod;
a=1LL*a*a%mod;
b>>=1;
}
return res;
}
int rev[maxn];
void ntt(int *A,int lim,int op)
{
for(int i=0;i<lim;i++)
if(rev[i]>i) swap(A[i],A[rev[i]]);
for(int i=1;i<lim;i<<=1)
{
int Wn=qpow(3,(mod-1)/(i<<1));
if(op==-1) Wn=qpow(Wn,mod-2);
for(int P=i<<1,j=0;j<lim;j+=P)
{
int w=1;
for(int k=0;k<i;k++,w=1LL*w*Wn%mod)
{
int x=A[j+k],y=1LL*A[j+k+i]*w%mod;
A[j+k]=(x+y)%mod;
A[j+k+i]=(x-y+mod)%mod;
}
}
}
if(op==-1)
{
int inv=qpow(lim,mod-2);
for(int i=0;i<lim;i++)
A[i]=1LL*A[i]*inv%mod;
}
}
void work(int dep,int *a,int *b)
{
if(dep==1)
{
b[0]=qpow(a[0],mod-2);
return;
}
work((dep+1)>>1,a,b);
int lim=1; int l=0;
while(lim<(dep+dep)) lim<<=1,l++;
for(int i=1;i<lim;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(l-1));
for(int i=0;i<dep;i++) c[i]=a[i];
for(int i=dep;i<lim;i++) c[i]=0;
ntt(c,lim,1); ntt(b,lim,1);
for(int i=0;i<lim;i++)
b[i]=1LL*(2-1LL*b[i]*c[i]%mod+mod)%mod*b[i]%mod;
ntt(b,lim,-1);
for(int i=dep;i<lim;i++) b[i]=0;
}
int main()
{
freopen("a.in","r",stdin);
freopen("a.out","w",stdout);
scanf("%d",&n);
for(int i=0;i<n;i++) scanf("%d",&a[i]);
work(n,a,b);
for(int i=0;i<n;i++) printf("%d ",b[i]);
return 0;
}
2.多项式开方
【代码】
#include<bits/stdc++.h>
using namespace std;
#define mp make_pair
#define fi first
#define se second
#define lson now<<1
#define rson now<<1|1
typedef long long ll;
const int maxn=4e5+5;
const int mod=998244353;
int n,f[maxn],ans[maxn],A[maxn],B[maxn],C[maxn],D[maxn];
int rev[maxn],inv2;
int qpow(int a,int b)
{
int res=1;
while(b)
{
if(b&1) res=1LL*res*a%mod;
a=1LL*a*a%mod;
b>>=1;
}
return res;
}
void ntt(int *A,int lim,int op)
{
for(int i=0;i<lim;i++)
if(i<rev[i]) swap(A[i],A[rev[i]]);
for(int i=1;i<lim;i<<=1)
{
int Wn=qpow(3,(mod-1)/(i<<1));
if(op==-1) Wn=qpow(Wn,mod-2);
for(int P=i<<1,j=0;j<lim;j+=P)
{
int w=1;
for(int k=0;k<i;k++,w=1LL*w*Wn%mod)
{
int x=A[j+k],y=1LL*A[j+k+i]*w%mod;
A[j+k]=(x+y)%mod;
A[j+k+i]=(x-y+mod)%mod;
}
}
}
if(op==-1)
{
int inv=qpow(lim,mod-2);
for(int i=0;i<lim;i++)
A[i]=1LL*A[i]*inv%mod;
}
}
void getinv(int *a,int *b,int lim)
{
b[0]=qpow(a[0],mod-2);
int limit;
for(int i=1;i<(lim+lim);i<<=1)
{
limit=i<<1;
for(int j=0;j<i;j++)
A[j]=a[j],B[j]=b[j];
for(int j=0;j<limit;j++)
rev[j]=(rev[j>>1]>>1)|((j&1)?i:0);
ntt(B,limit,1); ntt(A,limit,1);
for(int j=0;j<limit;j++)
b[j]=1LL*(2ll-1LL*A[j]*B[j]%mod+mod)%mod*B[j]%mod;
ntt(b,limit,-1);
for(int j=i;j<limit;j++) b[j]=0;
}
for(int i=0;i<limit;i++) B[i]=A[i]=0;
for(int i=lim;i<limit;i++) b[i]=0;
}
void getsqrt(int *a,int *b,int lim)
{
b[0]=1;
int limit;
int *A=C,*B=D;
for(int i=1;i<(lim+lim);i<<=1)
{
limit=i<<1;
for(int j=0;j<i;j++) A[j]=a[j];
getinv(b,B,i);
for(int j=1;j<limit;j++)
rev[j]=(rev[j>>1]>>1)|((j&1)?i:0);
ntt(A,limit,1); ntt(B,limit,1);
for(int j=0;j<limit;j++) A[j]=1LL*A[j]*B[j]%mod;
ntt(A,limit,-1);
for(int j=0;j<lim;j++) b[j]=1LL*(b[j]+A[j])%mod*inv2%mod;
for(int j=lim;j<limit;j++) b[j]=0;
}
for(int i=0;i<limit;i++) A[i]=B[i]=0;
for(int i=lim;i<limit;i++) b[i]=0;
}
int main()
{
freopen("a.in","r",stdin);
freopen("a.out","w",stdout);
scanf("%d",&n); inv2=qpow(2,mod-2);
for(int i=0;i<n;i++) scanf("%d",&f[i]);
getsqrt(f,ans,n);
for(int i=0;i<n;i++) printf("%d ",ans[i]);
return 0;
}
3.多项式ln,exp
【代码】
#include<bits/stdc++.h>
using namespace std;
#define mp make_pair
#define fi first
#define se second
#define lson now<<1
#define rson now<<1|1
typedef long long ll;
const int mod=998244353;
const int maxn=8e5+5;
int a[maxn],b[maxn],c[maxn],n;
int ga[maxn],gb[maxn],ans[maxn],f[maxn];
int qpow(int a,int b)
{
int res=1;
while(b)
{
if(b&1) res=1LL*res*a%mod;
a=1LL*a*a%mod;
b>>=1;
}
return res;
}
int rev[maxn];
void ntt(int *A,int lim,int op)
{
for(int i=0;i<lim;i++)
if(rev[i]>i) swap(A[i],A[rev[i]]);
for(int i=1;i<lim;i<<=1)
{
int Wn=qpow(3,(mod-1)/(i<<1));
if(op==-1) Wn=qpow(Wn,mod-2);
for(int P=i<<1,j=0;j<lim;j+=P)
{
int w=1;
for(int k=0;k<i;k++,w=1LL*w*Wn%mod)
{
int x=A[j+k],y=1LL*A[j+k+i]*w%mod;
A[j+k]=(x+y)%mod;
A[j+k+i]=(x-y+mod)%mod;
}
}
}
if(op==-1)
{
int inv=qpow(lim,mod-2);
for(int i=0;i<lim;i++)
A[i]=1LL*A[i]*inv%mod;
}
}
void getinv(int dep,int *a,int *b)
{
if(dep==1)
{
b[0]=qpow(a[0],mod-2);
return;
}
getinv((dep+1)>>1,a,b);
int lim=1; int l=0;
while(lim<(dep+dep)) lim<<=1,l++;
for(int i=1;i<lim;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(l-1));
for(int i=0;i<dep;i++) c[i]=a[i];
for(int i=dep;i<lim;i++) c[i]=0;
ntt(c,lim,1); ntt(b,lim,1);
for(int i=0;i<lim;i++)
b[i]=1LL*(2-1LL*b[i]*c[i]%mod+mod)%mod*b[i]%mod;
ntt(b,lim,-1);
for(int i=dep;i<lim;i++) b[i]=0;
}
void getdao(int *a,int *b,int lim)
{
for(int i=0;i<lim;i++)
b[i-1]=1LL*i*a[i]%mod;
b[lim-1]=0;
}
void jifen(int *a,int *b,int lim)
{
for(int i=1;i<lim;i++)
b[i]=1LL*a[i-1]*qpow(i,mod-2)%mod;
b[0]=0;
}
void work(int *a,int *b,int lim)
{
getdao(a,ga,lim); getinv(lim,a,gb);
int limit=1; int l=0;
while(limit<(lim+lim)) limit<<=1,l++;
for(int i=1;i<limit;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(l-1));
ntt(ga,limit,1); ntt(gb,limit,1);
for(int i=0;i<limit;i++)
a[i]=1LL*ga[i]*gb[i]%mod;
ntt(a,limit,-1);
jifen(a,b,lim);
}
int main()
{
scanf("%d",&n);
for(int i=0;i<n;i++) scanf("%d",&f[i]);
int len=1;
while(len<n) len<<=1;
work(f,ans,len);
for(int i=0;i<n;i++) printf("%d ",ans[i]);
return 0;
}
事实证明牛顿迭代虽然看起来比分治NTT少了了个log,但是由于常数较大?导致我的牛顿迭代比分治还慢
【代码】
#include<bits/stdc++.h>
using namespace std;
#define mp make_pair
#define fi first
#define se second
#define lson now<<1
#define rson now<<1|1
typedef long long ll;
const int mod=998244353;
const int maxn=8e5+5;
int c[maxn],n,lnb[maxn];
int ga[maxn],gb[maxn],ans[maxn],f[maxn];
int qpow(int a,int b)
{
int res=1;
while(b)
{
if(b&1) res=1LL*res*a%mod;
a=1LL*a*a%mod;
b>>=1;
}
return res;
}
int rev[maxn];
void ntt(int *A,int lim,int op)
{
for(int i=0;i<lim;i++)
if(rev[i]>i) swap(A[i],A[rev[i]]);
for(int i=1;i<lim;i<<=1)
{
int Wn=qpow(3,(mod-1)/(i<<1));
if(op==-1) Wn=qpow(Wn,mod-2);
for(int P=i<<1,j=0;j<lim;j+=P)
{
int w=1;
for(int k=0;k<i;k++,w=1LL*w*Wn%mod)
{
int x=A[j+k],y=1LL*A[j+k+i]*w%mod;
A[j+k]=(x+y)%mod;
A[j+k+i]=(x-y+mod)%mod;
}
}
}
if(op==-1)
{
int inv=qpow(lim,mod-2);
for(int i=0;i<lim;i++)
A[i]=1LL*A[i]*inv%mod;
}
}
void getinv(int dep,int *a,int *b)
{
if(dep==1)
{
b[0]=qpow(a[0],mod-2);
return;
}
getinv((dep+1)>>1,a,b);
int lim=1; int l=0;
while(lim<(dep+dep)) lim<<=1,l++;
for(int i=1;i<lim;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(l-1));
for(int i=0;i<dep;i++) c[i]=a[i];
for(int i=dep;i<lim;i++) c[i]=0;
ntt(c,lim,1); ntt(b,lim,1);
for(int i=0;i<lim;i++)
b[i]=1LL*(2-1LL*b[i]*c[i]%mod+mod)%mod*b[i]%mod;
ntt(b,lim,-1);
for(int i=dep;i<lim;i++) b[i]=0;
}
void getdao(int *a,int *b,int lim)
{
for(int i=1;i<lim;i++)
b[i-1]=1LL*i*a[i]%mod;
b[lim-1]=0;
}
void jifen(int *a,int *b,int lim)
{
for(int i=1;i<lim;i++)
b[i]=1LL*a[i-1]*qpow(i,mod-2)%mod;
b[0]=0;
}
int ta[maxn];
void getln(int *a,int *b,int lim)
{
for(int i=0;i<(lim<<2);i++) gb[i]=0;
getdao(a,ga,lim); getinv(lim,a,gb);
int limit=1; int l=0;
while(limit<(lim+lim)) limit<<=1,l++;
for(int i=1;i<limit;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(l-1));
for(int i=lim-1;i<limit;i++) ga[i]=0;
ntt(ga,limit,1); ntt(gb,limit,1);
for(int i=0;i<limit;i++)
ta[i]=1LL*ga[i]*gb[i]%mod;
ntt(ta,limit,-1);
jifen(ta,b,lim);
for(int i=lim;i<limit;i++) b[i]=0;
b[0]=0;
}
void getexp(int *a,int *b,int lim)
{
if(lim==1)
{
b[0]=1;
return;
}
getexp(a,b,(lim+1)>>1);
getln(b,lnb,lim);
int limit=1,l=0;
while(limit<(lim+lim)) limit<<=1,l++;
for(int i=1;i<limit;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(l-1));
for(int i=0;i<lim;i++) lnb[i]=(a[i]-lnb[i]+mod)%mod;
for(int i=lim;i<limit;i++) lnb[i]=b[i]=0;
lnb[0]++;
ntt(lnb,limit,1); ntt(b,limit,1);
for(int i=0;i<limit;i++)
b[i]=1LL*b[i]*lnb[i]%mod;
ntt(b,limit,-1);
for(int i=lim;i<limit;i++) b[i]=0;
}
int main()
{
scanf("%d",&n);
for(int i=0;i<n;i++) scanf("%d",&f[i]);
getexp(f,ans,n);
for(int i=0;i<n;i++) printf("%d ",ans[i]);
return 0;
}