http://uoj.ac/problem/34
fft真是一个丧心病狂的东西
递归版
#include<cmath>
#include<cstdio>
#include<vector>
// Inclusive counting loop: i takes the values s, s+1, ..., t.
// The `register` specifier is intentionally absent: it was removed in C++17.
#define FOR(i,s,t) for(int i=(s);i<=(t);++i)
typedef double db;
const db pi=acos(-1);
// Capacity of the global coefficient buffers a and b declared below.
const int N=500011;
// Minimal complex number with real part r and imaginary part i.
// Every operator is const, so const operands and temporaries can be combined.
struct complex{
    db r,i;
    typedef complex cp;
    // Component-wise sum.
    inline cp operator+(cp A)const{
        cp res={r+A.r,i+A.i};
        return res;
    }
    // Component-wise difference.
    inline cp operator-(cp A)const{
        cp res={r-A.r,i-A.i};
        return res;
    }
    // Complex product (r + i*j) * (A.r + A.i*j), with j*j = -1.
    inline cp operator*(cp A)const{
        cp res={r*A.r-i*A.i,r*A.i+i*A.r};
        return res;
    }
}a[N],b[N]; // global buffers for the two polynomials (static storage, so zero-initialized)
typedef complex cp;
// Recursive radix-2 FFT, computed in place on x[0..n-1].
//   x    : n complex values, overwritten with their transform
//   n    : transform length; the even/odd split assumes a power of two.
//          Lengths <= 1 are returned unchanged.
//   type : sign applied to the imaginary part of the root of unity
//          (the caller in this file passes 1 for the forward transform and
//          -1 for the inverse, and divides by n itself afterwards).
inline void fft(cp *x,int n,int type){
    if(n<=1)return;
    const int hf=n>>1;
    // Heap-backed scratch instead of variable-length arrays: VLAs are not
    // standard C++, and for large n the nested stack frames add up to several
    // megabytes.
    std::vector<cp> l(hf),r(hf);
    for(int i=0;i<hf;++i){
        l[i]=x[2*i];      // even-indexed samples
        r[i]=x[2*i+1];    // odd-indexed samples
    }
    fft(&l[0],hf,type);
    fft(&r[0],hf,type);
    // wn is the principal n-th root of unity (conjugated when type is -1);
    // w walks through its successive powers.
    const cp wn={cos(2*pi/n),sin(type*2*pi/n)};
    cp w={1,0};
    for(int i=0;i<hf;++i){
        cp t=w*r[i];
        x[i]=l[i]+t;
        x[i+hf]=l[i]-t;
        w=w*wn;
    }
}
int n,m,x;
int main(){
scanf("%d%d",&n,&m);
FOR(i,0,n)scanf("%d",&x),a[i].r=x;
FOR(i,0,m)scanf("%d",&x),b[i].r=x;
m+=n;for(n=1;n<=m;n<<=1);
fft(a,n,1);fft(b,n,1);
FOR(i,0,n)a[i]=a[i]*b[i];
fft(a,n,-1);
FOR(i,0,m)
printf("%d ",(int)(a[i].r/n+0.5));
return 0;
}
迭代版
#include<cstdio>
#include<cmath>
#include<algorithm>
#define gc getchar()
// Inclusive counting loop: i takes the values s, s+1, ..., t.
// The `register` specifier is intentionally absent: it was removed in C++17.
#define FOR(i,s,t) for(int i=(s);i<=(t);++i)
using std::swap;
typedef double db;
const db pi=acos(-1);
// Minimal complex number with real part r and imaginary part i.
struct complex{
    db r,i;
    typedef complex cp;
    // Component-wise sum.
    inline cp operator+(cp A)const{
        cp res={r+A.r,i+A.i};
        return res;
    }
    // Component-wise difference.
    inline cp operator-(cp A)const{
        cp res={r-A.r,i-A.i};
        return res;
    }
    // Complex product (r + i*j) * (A.r + A.i*j), with j*j = -1.
    inline cp operator*(cp A)const{
        cp res={r*A.r-i*A.i,r*A.i+A.r*i};
        return res;
    }
}a[1<<18],b[1<<18],wn[18]; // a, b: polynomial buffers; wn[t]: root of unity used by butterfly stage t
typedef complex cp;
// p[i] is the index that position i is swapped with before the butterfly passes.
int p[1<<18];
// n: transform length, m: degree of the product, lg2: log2 of the transform length.
int n,m,lg2;
// Iterative radix-2 FFT, computed in place on a[0..n-1].
// Relies on globals prepared by the caller: n (transform length), p (the
// reordering permutation) and wn (one root of unity per butterfly stage).
// The direction of the transform is decided solely by the contents of wn.
inline void fft(cp *a){
    // Apply the permutation; the i<p[i] test makes each pair swap exactly once.
    for(int i=0;i<n;++i)
        if(i<p[i])swap(a[i],a[p[i]]);
    // half is the distance between the two inputs of a butterfly; it doubles
    // every stage while stage indexes the matching entry of wn.
    for(int half=1,stage=0;half<n;half<<=1,++stage){
        const int span=half<<1;
        const cp root=wn[stage];
        for(int base=0;base<n;base+=span){
            cp cur={1,0}; // running power of root within this group
            for(int k=base;k<base+half;++k){
                const cp y=cur*a[k+half];
                a[k+half]=a[k]-y;
                a[k]=a[k]+y;
                cur=cur*root;
            }
        }
    }
}
// Reads the next decimal integer from stdin and returns it.
// Every character that is neither a digit nor '-' is skipped first, so any
// whitespace (including '\r' and '\t') may separate numbers. A leading '-'
// negates the value. Returns 0 if end of input is reached before a number.
// The character that terminates the number is consumed.
inline int read(){
    int c=getchar(); // int, not char, so EOF stays distinguishable
    while(c!=EOF&&c!='-'&&(c<'0'||c>'9'))c=getchar();
    if(c==EOF)return 0;
    bool neg=false;
    if(c=='-'){
        neg=true;
        c=getchar();
    }
    int data=0;
    while(c>='0'&&c<='9'){
        data=data*10+(c-'0');
        c=getchar();
    }
    return neg?-data:data;
}
// Digit stack used by write(): wr[0] is the number of stacked digits and
// wr[1..wr[0]] hold them, least significant first.
int wr[51];
// Prints x in decimal on stdout with no surrounding whitespace.
// Negative values are printed with a leading '-'; the magnitude is taken in
// unsigned arithmetic so the most negative int is handled as well.
inline void write(int x){
    if(!x){
        putchar(48);
        return;
    }
    unsigned int mag=x;
    if(x<0){
        putchar('-');
        mag=0u-mag;
    }
    while(mag){
        wr[++wr[0]]=mag%10;
        mag/=10;
    }
    // Pop the digits back off so the most significant one is printed first.
    while(wr[0]){
        putchar(48+wr[wr[0]]);
        --wr[0];
    }
}
int main(){
n=read();m=read();
FOR(i,0,n)a[i].r=1.00*read();
FOR(i,0,m)b[i].r=1.00*read();
m+=n;for(n=1;n<=m;n<<=1)++lg2;
FOR(i,0,n-1)p[i]=(p[i>>1]>>1)^((i&1)<<(lg2-1));
for(register int i=1,t=0;i<n;i<<=1,++t)wn[t]=(cp){cos(pi/i),sin(pi/i)};
fft(a);fft(b);
FOR(i,0,n-1)a[i]=a[i]*b[i];
for(register int i=1,t=0;i<n;i<<=1,++t)wn[t]=(cp){cos(pi/i),sin(-pi/i)};
fft(a);
FOR(i,0,m)write((int)(a[i].r/n+0.5)),putchar(' ');
return 0;
}
NTT版
#include<cstdio>
#include<algorithm>
using namespace std;
// Modulus 479 * 2^21 + 1 used by every modular operation in this program.
const int mod=(479<<21)+1;
// Capacity of each work array below.
const int maxn=1000000;
// a, b: coefficient arrays that are transformed in place.
int a[maxn],b[maxn];
// p: index permutation applied at the start of ntt().
int p[maxn];
// s: factor list filled by get_g(); s[0] is the count, s[1..] the factors.
int s[maxn];
// gn[t]: root of unity used by butterfly stage t.
int gn[maxn];
int n,m,lg2,g,ny;
// Returns a raised to the power b, reduced modulo mod, by repeated squaring.
// b is expected to be non-negative.
inline int fp(int a,int b){
    int acc=1;
    for(;b;b>>=1){
        if(b&1)acc=1ll*a*acc%mod;
        a=1ll*a*a%mod;
    }
    return acc;
}
// Returns the smallest i >= 2 such that fp(i,(p-1)/q) != 1 for every distinct
// prime factor q of p-1, i.e. a primitive-root candidate test.
// The distinct prime factors are left in the global s (s[0] = count).
// NOTE(review): fp() reduces modulo the global mod, not modulo p, so the
// result is only meaningful when p == mod (which is how main calls it).
inline int get_g(int p){
    s[0]=0; // start a fresh factor list so repeated calls do not accumulate
    int x=p-1;
    // Trial division; the 64-bit product keeps i*i from overflowing int.
    for(int i=2;1ll*i*i<=x;++i){
        if(x%i)continue;
        while(x%i==0)x/=i;
        s[++s[0]]=i;
    }
    if(x>1)s[++s[0]]=x; // whatever remains is itself a prime factor
    for(int cand=2;;++cand){
        bool ok=true;
        for(int j=1;j<=s[0]&&ok;++j)
            if(fp(cand,(p-1)/s[j])==1)ok=false;
        if(ok)return cand;
    }
}
// In-place number-theoretic transform of a[0..m-1] modulo mod.
// Relies on globals prepared by the caller: m (transform length), p (the
// reordering permutation) and gn (one root of unity per butterfly stage).
// Entries of a are expected to already lie in [0, mod).
inline void ntt(int *a){
    // Apply the permutation; the i<p[i] test makes each pair swap exactly once.
    for(int i=0;i<m;++i)
        if(i<p[i])std::swap(a[i],a[p[i]]);
    // half is the distance between the two inputs of a butterfly.
    for(int half=1,t=0;half<m;half<<=1,++t){
        const int len=half<<1;
        const int root=gn[t];
        for(int j=0;j<m;j+=len){
            int w=1; // running power of root within this group
            for(int k=j;k<j+half;++k){
                const int v=1ll*w*a[half+k]%mod;
                a[half+k]=(a[k]-v+mod)%mod; // +mod keeps the difference non-negative
                a[k]=(a[k]+v)%mod;
                w=1ll*w*root%mod;
            }
        }
    }
}
int main(){
g=get_g(mod);
scanf("%d%d",&n,&m);
for(register int i=0;i<=n;++i)scanf("%d",a+i);
for(register int i=0;i<=m;++i)scanf("%d",b+i);
n+=m;for(m=1;m<=n;m<<=1)++lg2;
for(register int i=0;i<m;++i)p[i]=(p[i>>1]>>1)^((i&1)<<(lg2-1));
for(register int i=1,t=0;i<m;i<<=1,++t)gn[t]=fp(g,(mod-1)/(i<<1));
ntt(a);ntt(b);
for(register int i=0;i<m;++i)a[i]=1ll*a[i]*b[i]%mod;
ntt(a);
reverse(a+1,a+m);
ny=fp(m,mod-2);
for(register int i=0;i<m;++i)a[i]=1ll*a[i]*ny%mod;
for(register int i=0;i<=n;++i)printf("%d ",a[i]);
return 0;
}
多项式求逆元
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
const int mod=998244353;
// solve() runs transforms of length p, the smallest power of two that is at
// least 2*deg, and indexes tmp, b and s up to p-1. The capacity therefore has
// to be a power-of-two bound rather than 2*deg itself: 2^18 covers every
// deg <= 2^17, whereas the former 2e5+5 was overrun as soon as deg > 65536.
const int maxn=(1<<18)+5;
// a: input coefficients, b: inverse being built, tmp: scratch copy of a,
// s: index permutation for ntt(), gn[t]: root of unity for butterfly stage t.
int a[maxn],b[maxn],tmp[maxn],s[maxn],gn[maxn];
int n;
// Returns a raised to the power b, reduced modulo mod, by repeated squaring.
// b is expected to be non-negative.
inline int fp(int a,int b){
    int ret=1;
    while(b){
        if(b&1)ret=1ll*a*ret%mod;
        a=1ll*a*a%mod;
        b>>=1;
    }
    return ret;
}
// In-place number-theoretic transform of a[0..p-1] modulo mod.
//   a : values in [0, mod), overwritten with the result
//   p : transform length (the butterflies assume a power of two)
//   f : 1 leaves the forward transform in a; any other value additionally
//       reverses a[1..p-1] and multiplies by the modular inverse of p, which
//       is how the caller obtains the inverse transform
// Relies on the globals s (reordering permutation for this p) and gn (one
// root of unity per butterfly stage).
inline void ntt(int *a,int p,int f){
    // Apply the permutation; the i<s[i] test makes each pair swap exactly once.
    for(int i=0;i<p;++i)
        if(i<s[i])std::swap(a[i],a[s[i]]);
    // half is the distance between the two inputs of a butterfly.
    for(int half=1,t=0;half<p;half<<=1,++t){
        const int g=gn[t];
        for(int j=0;j<p;j+=(half<<1)){
            int w=1; // running power of g within this group
            for(int k=j;k<j+half;++k){
                const int v=1ll*w*a[half+k]%mod;
                a[half+k]=(a[k]-v+mod)%mod; // +mod keeps the difference non-negative
                a[k]=(a[k]+v)%mod;
                w=1ll*w*g%mod;
            }
        }
    }
    if(f==1)return;
    std::reverse(a+1,a+p);
    const int ny=fp(p,mod-2);
    for(int i=0;i<p;++i)
        a[i]=1ll*a[i]*ny%mod;
}
// Writes into b[0..deg-1] the first deg coefficients of the inverse of the
// global polynomial a modulo x^deg, by recursive doubling:
// the inverse modulo x^ceil(deg/2) is refined with b <- b*(2 - a*b).
// Entries of b from index deg up to the transform length are left holding
// discarded product terms; each outer level clears them before reuse.
// NOTE(review): a[0] must be invertible modulo mod for the base case.
inline void solve(int *b,int deg){
    if(deg<=1){
        // Constant term: plain modular inverse. deg <= 0 has nothing to do.
        if(deg==1)b[0]=fp(a[0],mod-2);
        return;
    }
    const int half=(deg+1)>>1;
    solve(b,half);
    // Transform length: smallest power of two that is at least 2*deg.
    int p=1,lg2=0;
    while(p<(deg<<1))p<<=1,++lg2;
    // tmp = a truncated to deg terms, zero-padded to length p.
    for(int i=0;i<p;++i)tmp[i]=i<deg?a[i]:0;
    // Keep only the half terms computed so far; clear everything above them.
    for(int i=half;i<p;++i)b[i]=0;
    // Rebuild the permutation for this length: s[i] comes from s[i>>1].
    for(int i=0;i<p;++i)s[i]=(s[i>>1]>>1)^((i&1)<<(lg2-1));
    ntt(tmp,p,1);ntt(b,p,1);
    // Pointwise 2*b - a*b*b on the transformed values.
    for(int i=0;i<p;++i)b[i]=(2ll*b[i]%mod-1ll*tmp[i]*b[i]%mod*b[i]%mod+mod)%mod;
    ntt(b,p,-1);
}
int main(){
for(register int t=0,i=1;t<=20;i<<=1,++t)
gn[t]=fp(3,(mod-1)/(i<<1));
scanf("%d",&n);
for(register int i=0;i<=n;++i)scanf("%d",a+i);
solve(b,n+1);
for(register int i=0;i<=n;++i)printf("%d ",b[i]);
return 0;
}