1.核心:
FFT:
正常版本:
#include<bits/stdc++.h>
#define maxn 400005
using namespace std;
const double PI = acos(-1);
// Minimal complex number type: r is the real part, i the imaginary part.
struct cplx
{
    double r,i;
    // Builds the value r + i*sqrt(-1); both parts default to zero.
    cplx(double r=0,double i=0):r(r),i(i){}
    // Component-wise sum.
    cplx operator +(const cplx &B)const{ return cplx(r+B.r,i+B.i); }
    // Component-wise difference.
    cplx operator -(const cplx &B)const{ return cplx(r-B.r,i-B.i); }
    // Complex product: (r + i*j)(B.r + B.i*j).
    cplx operator *(const cplx &B)const{ return cplx(r*B.r-i*B.i,i*B.r+r*B.i); }
    // Complex conjugate. Const-qualified so it also works on const objects and
    // on values reached through const references.
    cplx conj()const{ return cplx(r,-i); }
}a[maxn],b[maxn]; // a: packed input / its spectrum, b: spectrum of the product (see main)
// Bit-reversal permutation table. r[0] stays 0; entries 1..n-1 are rebuilt on every FFT call.
int r[maxn] = {};
// Twiddle-factor scratch table. w[0] is fixed at 1; w[1..l-1] is refilled for each stage.
cplx w[maxn] = {1};
/*
 * In-place iterative radix-2 FFT.
 *
 * A   : array holding at least (1 << lgn) elements; overwritten with the transform.
 * lgn : base-2 logarithm of the transform length.
 * tp  : 1 for the forward transform, -1 for the inverse transform (the sign is
 *       applied to the sine of the stage root, and -1 additionally scales by 1/n).
 */
inline void FFT(cplx *A,int lgn,int tp)
{
    int n = 1<<lgn;
    // r[i] is i with its lgn low bits reversed, derived from r[i>>1].
    for(int i=1;i<n;i++) r[i] = (r[i>>1]>>1) | ((i&1)<<(lgn-1));
    // Swap each pair once so the butterflies can run in place.
    for(int i=1;i<n;i++) if(i < r[i]) swap(A[i] , A[r[i]]);
    for(int len=2;len<=n;len<<=1){
        int l = len >> 1;
        // Primitive len-th root of unity (conjugated for the inverse transform).
        cplx wn(cos(PI / l) , sin(PI / l) * tp);
        // Powers of wn are built by repeated multiplication starting from w[0] == 1.
        for(int i=1;i<l;i++) w[i] = w[i-1] * wn;
        for(int st = 0;st < n;st += len) for(int k=0;k<l;k++)
        {
            cplx tmp = w[k] * A[st + k + l];
            A[st + k + l] = A[st + k] - tmp;
            A[st + k] = A[st + k] + tmp;
        }
    }
    // Inverse transform: scale BOTH components by 1/n. Scaling only the real
    // part would leave the imaginary part n times too large.
    if(tp==-1) for(int i=0;i<n;i++) A[i].r /= n , A[i].i /= n;
}
/*
 * Reads the degrees n and m, then n+1 and m+1 coefficients, and prints the
 * n+m+1 coefficients of the product, each followed by a space.
 *
 * Both inputs are packed into one complex array (first polynomial in the real
 * parts, second in the imaginary parts) so a single forward FFT suffices.
 */
int main()
{
    int n,m;
    // Without both sizes nothing sensible can be computed; n and m would be uninitialised.
    if(scanf("%d%d",&n,&m) != 2) return 0;
    for(int i=0;i<=n;i++) scanf("%lf",&a[i].r);
    for(int i=0;i<=m;i++) scanf("%lf",&a[i].i);
    n++,m++; // from here on n and m are coefficient counts, not degrees
    int len = 0;
    for(;n+m>(1<<len);len++); // smallest power of two that is >= n+m
    FFT(a,len,1);
    for(int i=0,ci,Len = 1<<len;i<Len;i++)
    {
        // ci is (-i) mod Len; the two spectra are separated from a[i] and conj(a[ci]).
        ci = (Len - i) & (Len - 1);
        cplx A = (a[i] + a[ci].conj())*cplx(0.5,0) , B = (a[i] - a[ci].conj())*cplx(0,-0.5);
        b[i] = A * B;
    }
    FFT(b,len,-1);
    // round() handles negative coefficients correctly; int(x+0.5) truncates towards zero.
    for(int i=0;i<n+m-1;i++) printf("%d ",int(round(b[i].r)));
}
预处理单位根(精度高):
#include<bits/stdc++.h>
#define maxn 300005
using namespace std;
const double Pi = 3.1415926535897932384626433832795;
// Plain complex value with real part r and imaginary part i.
struct cplx
{
    double r,i;

    // Both components default to zero.
    cplx(double re=0,double im=0):r(re),i(im){}

    // Sum of two complex values.
    cplx operator +(const cplx &rhs)const
    {
        return cplx(r+rhs.r , i+rhs.i);
    }

    // Difference of two complex values.
    cplx operator -(const cplx &rhs)const
    {
        return cplx(r-rhs.r , i-rhs.i);
    }

    // Product of two complex values.
    cplx operator *(const cplx &rhs)const
    {
        const double re = r*rhs.r-i*rhs.i;
        const double im = i*rhs.r+r*rhs.i;
        return cplx(re , im);
    }

    // Conjugate: same real part, negated imaginary part.
    cplx conj()const
    {
        return cplx(r , -i);
    }
};
// w: root-of-unity table filled by FFT; A, B: working arrays used by main.
cplx w[maxn],A[maxn],B[maxn];
// Bit-reversal permutation table. r[0] is never written and stays 0 (static storage).
int r[maxn];
/*
 * In-place iterative radix-2 FFT that reads its twiddle factors from a table
 * computed directly with cos/sin rather than by repeated multiplication.
 *
 * A   : array holding at least (1 << lgn) elements; overwritten with the transform.
 * lgn : base-2 logarithm of the transform length.
 * tp  : 1 for the forward transform; -1 for the inverse transform, which uses
 *       the conjugated table entries and scales the result by 1/n.
 */
inline void FFT(cplx A[maxn],int lgn,int tp)
{
    int n = 1<<lgn;
    // w[i] = (cos(i*Pi/n), sin(i*Pi/n)), so w[k*n/l] is the k-th power of the 2l-th root.
    for(int i=0;i<n;i++) w[i]=cplx(cos(i*Pi/n),sin(i*Pi/n));
    // Start at 1: r[0] is always 0, and for lgn == 0 starting at 0 would evaluate
    // a shift by (lgn-1) == -1, which is undefined behaviour.
    for(int i=1;i<n;i++) r[i] = (r[i>>1]>>1)|((i&1)<<(lgn-1));
    for(int i=1;i<n;i++) if(i < r[i]) swap(A[i] , A[r[i]]);
    for(int L=2;L<=n;L<<=1)
        for(int st=0,l=L>>1;st<n;st+=L)
            // lc walks the table with stride n/l so that w[lc] matches butterfly k.
            for(int k=0,lc=0,inc=n/l;k<l;k++,lc+=inc)
            {
                cplx tmp = (tp==1 ? w[lc] : w[lc].conj()) * A[st+k+l];
                A[st+k+l]=A[st+k]-tmp,A[st+k]=A[st+k]+tmp;
            }
    if(tp==-1) for(int i=0;i<n;i++) A[i].r/=n,A[i].i/=n;
}
int main()
{
int n,m;
scanf("%d%d",&n,&m);
for(int i=0;i<=n;i++) scanf("%lf",&A[i].r);
for(int i=0;i<=m;i++) scanf("%lf",&A[i].i);
int lgn=0;for(;n+m>=(1<<lgn);lgn++);
FFT(A,lgn,1);
for(int i=0,len=1<<lgn;i<len;i++)
{
cplx u=A[i],v=A[(len-1)&(len-i)].conj();
B[i]=(u+v)*(u-v)*cplx(0,-0.25);
}
FFT(B,lgn,-1);
for(int i=0;i<n+m;i++) printf("%d ",int(round(B[i].r)));
printf("%d\n",(int)round(B[n+m].r));
}
$\mathrm{UPD: 1.0KB\ FFT}$
#include<bits/stdc++.h>
#define maxn 300005
#define cp complex<double>
#define Pi 3.1415926535897932384626433832795
#define rep(i,j,k) for(int i=(j);i<=(k);i++)
#define per(i,j,k) for(int i=(j);i>=(k);i--)
using namespace std;
// Wl: half of the transform length chosen by init().
int Wl;
// lg[i]: floor(log2(i)) for the indices filled by init(); lg[0] and lg[1] stay 0.
int lg[maxn];
// r: bit-reversal table written by FFT().
int r[maxn];
// W[i] = exp(sqrt(-1) * i * Pi / Wl) for i in [0, 2*Wl].
cp W[maxn];
/*
 * Prepares the tables for products whose degree sum is n: doubles Wl (starting
 * from 1) while 2*Wl <= n, then fills W[0..2*Wl] and lg[2..2*Wl].
 */
void init(int n){
    Wl=1;
    while(n>=2*Wl) Wl<<=1;
    const int last=Wl<<1;
    for(int idx=0;idx<=last;++idx)
    {
        W[idx]=exp(cp(0,idx*Pi/Wl));
        if(idx>1) lg[idx]=lg[idx>>1]+1;
    }
}
/*
 * In-place radix-2 FFT over the first n entries of A, using the W and lg
 * tables prepared by init(). tp == 1 selects the forward transform; any other
 * value uses conjugated twiddles and divides every entry by n afterwards.
 */
void FFT(cp *A,int n,int tp){
    // Build the bit-reversal table and permute A in the same pass.
    for(int idx=1;idx<n;++idx)
    {
        r[idx]=(r[idx>>1]>>1)|((idx&1)<<(lg[n]-1));
        if(idx<r[idx]) swap(A[idx],A[r[idx]]);
    }

    // half is the butterfly half-width; stride is the matching step through W.
    for(int half=1,stride=Wl;half<n;half<<=1,stride>>=1)
    {
        for(int base=0;base<n;base+=half<<1)
        {
            for(int k=base,x=0;k<base+half;++k,x+=stride)
            {
                cp t=(tp==1?W[x]:conj(W[x]))*A[k+half];
                A[k+half]=A[k]-t;
                A[k]+=t;
            }
        }
    }

    if(tp!=1)
    {
        for(int idx=0;idx<=n-1;++idx) A[idx]/=n;
    }
}
// Degrees of the two input polynomials.
int n,m;
// Coefficient arrays; A is overwritten with the product.
cp A[maxn],B[maxn];
/*
 * Reads n, m and the two coefficient lists, multiplies the polynomials with
 * two forward FFTs, a pointwise product and one inverse FFT, then prints the
 * n+m+1 rounded coefficients separated by spaces with a trailing newline.
 */
int main(){
    scanf("%d%d",&n,&m);

    double coef;
    for(int idx=0;idx<=n;++idx)
    {
        scanf("%lf",&coef);
        A[idx].real(coef);
    }
    for(int idx=0;idx<=m;++idx)
    {
        scanf("%lf",&coef);
        B[idx].real(coef);
    }

    init(n+m);

    FFT(A,Wl<<1,1);
    FFT(B,Wl<<1,1);
    for(int idx=0;idx<=(Wl<<1)-1;++idx) A[idx]*=B[idx];
    FFT(A,Wl<<1,-1);

    for(int idx=0;idx<=n+m;++idx)
    {
        const char sep=(idx==n+m)?'\n':' ';
        printf("%d%c",(int)round(A[idx].real()),sep);
    }
}
MTT
合并DFT详见myy论文。
合并IDFT其实不需要任何技巧因为:
$IDFT(DFT(A(i)) + i \cdot DFT(B(i))) = A(i) + iB(i)$
如果觉得慢的话可以将 $\mathrm{long\ double}$ 改为 $\mathrm{double}$
然后预处理单位根照样可以满足 $10^5$ 的精度要求
#include<bits/stdc++.h>
#define maxn 300005
#define LL long long
#define M ((1<<15)-1)
#define ld long double
using namespace std;
// Buffered stdin state: cb is the buffer and [cs, ct) is its unread part.
char cb[1<<15],*cs=cb,*ct=cb;
// Yields the next input byte, refilling cb with fread when it is empty.
// Yields 0 when the refill reads nothing.
#define getc() (cs==ct&&(ct=(cs=cb)+fread(cb,1,1<<15,stdin),cs==ct)?0:*cs++)
/*
 * Reads the next run of decimal digits from stdin into res, skipping any
 * non-digit bytes in front of it. The byte that terminates the run is consumed.
 * No sign is parsed. If the input ends before a digit is found, res is set to 0
 * and the function returns instead of looping forever.
 */
inline void read(int &res){
    int ch = getc();
    // isdigit requires a value representable as unsigned char, hence the casts.
    while(!isdigit(static_cast<unsigned char>(ch))){
        // getc() yields 0 both for a NUL byte and for an empty refill; only the
        // latter (buffer drained and stream at EOF or in error) ends the search.
        if(ch == 0 && cs == ct && (feof(stdin) || ferror(stdin))){ res = 0; return; }
        ch = getc();
    }
    for(res = 0; isdigit(static_cast<unsigned char>(ch)); ch = getc()) res = res*10 + (ch - '0');
}
int p; // NOTE(review): not used in the visible code; presumably the modulus — confirm against the rest of the file
// The L suffix keeps the literal at long double precision; without it the
// value is rounded to double before being stored.
const ld Pi = 3.1415926535897932384626433832795L;
// Complex number with long double components: r is the real part, i the imaginary part.
struct cplx
{
    ld r,i;
    // Both components default to zero.
    cplx(ld r=0,ld i=0):r(r),i(i){}
    // Component-wise sum.
    cplx operator +(const cplx &B)const{ return cplx(r+B.r,i+B.i); }
    // Component-wise difference.
    cplx operator -(const cplx &B)const{ return cplx(r-B.r,i-B.i); }
    // Complex product.
    cplx operator *(const cplx &B)const{ return cplx(r*B.r-i*B.i,i*B.r+B.i*r); }
    // Complex conjugate. Const-qualified so it can be used on const objects
    // and through const references.
    cplx conj()const{ return cplx(r,-i); }
}w[maxn]={1}; // twiddle table; w[0] is initialised to 1
// a, b, c: integer arrays declared for the surrounding program; r: bit-reversal table written by FFT.
int a[maxn],b[maxn],c[maxn],r[maxn];
/*
 * In-place iterative radix-2 FFT on long double complex values.
 *
 * A   : array with at least (1 << lgn) elements, overwritten with the result.
 * lgn : base-2 logarithm of the transform length.
 * tp  : 1 for the forward transform, -1 for the inverse (conjugated roots and
 *       a final division of both components by the length).
 */
inline void FFT(cplx A[maxn],int lgn,int tp)
{
    const int n = 1<<lgn;

    // Bit-reversal table first, then the permutation itself.
    for(int idx=1;idx<n;++idx)
        r[idx] = (r[idx>>1]>>1)|((idx&1)<<(lgn-1));
    for(int idx=1;idx<n;++idx)
    {
        if(idx<r[idx]) swap(A[idx],A[r[idx]]);
    }

    for(int half=1;half<n;half<<=1)
    {
        const int span = half<<1;

        // w[0] is 1 from its initialiser; w[1] is the primitive root for this
        // stage and the remaining powers follow by repeated multiplication.
        w[1] = cplx(cos(Pi/half),sin(Pi/half)*tp);
        for(int idx=2;idx<half;++idx) w[idx] = w[idx-1] * w[1];

        for(int base=0;base<n;base+=span)
        {
            for(int k=0;k<half;++k)
            {
                cplx &lo = A[base+k];
                cplx &hi = A[base+k+half];
                cplx tmp = w[k] * hi;
                hi = lo - tmp;
                lo = lo + tmp;
            }
        }
    }

    if(tp == -1)
    {
        for(int idx=0;idx<n;++idx)
        {
            A[idx].r/=n;
            A[idx].i/=n;
        }
    }
}
// Four scratch arrays of complex values. The visible start of mul() fills s[0]
// with the parts above bit 15 (value >> 15) and s[1] with the low 15 bits
// (value & M) of its two integer inputs before transforming both.
cplx s[4][maxn];
inline void mul(int a[maxn],int b[maxn],int lgn,int c[maxn])
{
int n = 1<<lgn;
for(int i=0;i<n;i++) s[0][i] = cplx(a[i]>>15,b[i]>>15) , s[1][i] = cplx(a[i]&M,b[i]&M);
FFT(s[0],lgn,1),FFT(s[1],lgn,1);
for(int i=0;i