基础知识
离散傅里叶(DFT)和逆变换(IDFT)时间复杂度都是
O
(
n
2
)
O(n^2)
O(n2)
快速傅里叶(FFT)时间复杂度为
O
(
n
l
o
g
2
n
)
O(nlog_2n)
O(nlog2n)
推导过程:
设
A
(
x
)
=
a
0
+
a
1
x
+
a
2
x
2
+
⋯
+
a
n
x
n
A(x)=a_0+a_1x+a_2x^2+\dots+a_{n}x^n
A(x)=a0+a1x+a2x2+⋯+anxn,假设n为偶数,根据奇偶性分为:
A
(
x
)
=
(
a
0
+
a
2
x
2
+
⋯
+
a
n
x
n
)
+
(
a
1
x
+
a
3
x
3
+
⋯
+
a
n
−
1
x
n
−
1
A(x)=(a_0+a_2x^2+\dots+a_nx^n)+(a_1x+a_3x^3+\dots+a_{n-1}x^{n-1}
A(x)=(a0+a2x2+⋯+anxn)+(a1x+a3x3+⋯+an−1xn−1
设多项式
A
1
(
x
)
A_1(x)
A1(x)和
A
2
(
x
)
A_2(x)
A2(x)
A
1
(
x
)
=
a
0
+
a
2
x
+
⋯
+
a
n
x
n
2
A_1(x)=a_0+a_2x+\dots+a_nx^{\frac n2}
A1(x)=a0+a2x+⋯+anx2n
A
2
(
x
)
=
a
1
+
a
3
x
+
⋯
+
a
n
−
1
x
n
2
−
1
A_2(x)=a_1+a_3x+\dots+a_{n-1}x^{\frac {n}2-1}
A2(x)=a1+a3x+⋯+an−1x2n−1
因此可以得到:
A
(
x
)
=
A
1
(
x
2
)
+
x
A
2
(
x
2
)
A(x)=A_1(x^2)+xA_2(x^2)
A(x)=A1(x2)+xA2(x2)
设
k
<
n
2
k<\frac n2
k<2n,把
x
=
w
n
k
x=w_n^k
x=wnk代入:
A
(
w
n
k
)
=
A
1
(
w
n
2
k
)
+
w
n
k
A
2
(
w
n
2
k
)
A(w_n^k)=A_1(w_n^{2k})+w_n^kA_2(w_n^{2k})
A(wnk)=A1(wn2k)+wnkA2(wn2k)
A ( w n k ) = A 1 ( w n 2 k ) + w n k A 2 ( w n 2 k ) A(w_n^k)=A_1(w_{\frac n2}^{k})+w_n^kA_2(w_{\frac n2}^{k}) A(wnk)=A1(w2nk)+wnkA2(w2nk)
把
x
=
w
n
k
+
n
2
x=w_n^{k+\frac n2}
x=wnk+2n代入可得:
A
(
w
n
k
+
n
2
)
=
A
1
(
w
n
2
k
+
n
)
+
w
n
k
+
n
2
A
2
(
w
n
2
k
+
n
)
A(w_n^{k+\frac n2})=A_1(w_n^{2k+n}) +w_n^{k+\frac n2}A_2(w_n^{2k+n})
A(wnk+2n)=A1(wn2k+n)+wnk+2nA2(wn2k+n)
A ( w n k + n 2 ) = A 1 ( w n 2 k ) − w n k A 2 ( w n 2 k ) A(w_n^{k+\frac n2})=A_1(w_{\frac n2}^{k}) -w_n^{k}A_2(w_{\frac n2}^{k}) A(wnk+2n)=A1(w2nk)−wnkA2(w2nk)
观察化简后的两式,可知只要知道了 A 1 ( w n 2 k ) A_1(w_{\frac n2}^{k}) A1(w2nk)和 A 2 ( w n 2 k ) A_2(w_{\frac n2}^{k}) A2(w2nk),就可以求得 A ( w n k + n 2 ) A(w_n^{k+\frac n2}) A(wnk+2n)和 A ( w n k ) A(w_n^k) A(wnk),也就是通过下层的两个值,求得当前层的两个值
FFT进行多项式乘法的步骤
1、对两个多项式补0
2、用FFT计算两个多项式A、B的点值表示法
3、得到乘积多项式C的点值表示法
4、用FFT通过IDFT计算多项式C的系数表示
P3803 【模板】多项式乘法(FFT)
题意:给定多项式A和B,计算两式相乘之后的多项式系数
细节:n和m的范围是1e6,相加是2e6,那么对多项式补0,最多是4e6。所以开4e6的空间就可以了
带注释代码:
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cmath>
#define rep(i,a,b) for (int i=a; i<=b; ++i)
using namespace std;
const int maxn=4e6+10,INF=0x3f3f3f3f;
const double PI=acos(-1.0);
struct Complex
{
double x,y;
Complex(double x1=0,double y1=0)
{
x=x1,y=y1;
}
};
Complex operator+(Complex a,Complex b)
{
return {a.x+b.x,a.y+b.y};
}
Complex operator-(Complex a,Complex b)
{
return {a.x-b.x,a.y-b.y};
}
Complex operator*(Complex a,Complex b)
{
return {a.x*b.x-a.y*b.y,a.x*b.y+b.x*a.y};
}
Complex a[maxn],b[maxn];
int total=1,digit=0;
int rev[2*maxn];
void fft(Complex *A,int f)
{
//构造最终序列,我们只求出了i的翻转后的rev[i],但是并没有对实际的数组做改造
//判断i<rev[i]可以避免对两个元素交换两次
for(int i=0;i<total;++i)
if(i<rev[i])
swap(A[i],A[rev[i]]);
//构造完序列后,枚举中点
for(int mid=1;mid<total;mid<<=1)
{
//每次都乘上一个k
Complex Wn={cos(PI/mid),f*sin(PI/mid)};
int len=mid<<1;
//遍历每一段长度
for(int p=0;p<total;p+=len)
{
//起点是(1,0)
Complex Wk={1,0};
//枚举从[0,mid/2) 枚举k
for(int k=0;k<mid;++k)
{
//蝴蝶效应,记录下层的两个值,然后修改当前层的两个值
Complex x=A[p+k],y=Wk*A[p+k+mid];
A[p+k]=x+y;
A[p+k+mid]=x-y;
Wk=Wk*Wn;
}
}
}
}
int main()
{
int n,m;
scanf("%d%d",&n,&m);
rep(i,0,n)
scanf("%lf",&a[i].x);
rep(i,0,m)
scanf("%lf",&b[i].x);
while(total<=n+m)
total<<=1,digit++;
//i和i/2的关系是:i右移一位得到i/2,那么翻转够来之后,就是i/2右移一位得到i
//并且要考虑i的最低位,翻转之后要加在最高位上
for(int i=0;i<total;++i)
rev[i]=(rev[i>>1]>>1)|(i&1)<<(digit-1);
//求出多项式a、b的点值表示
fft(a,1);
fft(b,1);
//通过点值表示合并以后的多项式c
for(int i=0;i<=total;++i)
a[i]=a[i]*b[i];
//将c的点值表示,重新还原为系数表示
fft(a,-1);
for(int i=0;i<=n+m;++i)
printf("%d ",(int)(a[i].x/total+0.5));
return 0;
}
不带注释代码:
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cmath>
#define rep(i,a,b) for (int i=a; i<=b; ++i)
using namespace std;
const int maxn=4e6+10,INF=0x3f3f3f3f;
const double PI=acos(-1.0);
struct Complex
{
double x,y;
Complex(double x1=0,double y1=0)
{
x=x1,y=y1;
}
};
Complex operator+(Complex a,Complex b)
{
return {a.x+b.x,a.y+b.y};
}
Complex operator-(Complex a,Complex b)
{
return {a.x-b.x,a.y-b.y};
}
Complex operator*(Complex a,Complex b)
{
return {a.x*b.x-a.y*b.y,a.x*b.y+b.x*a.y};
}
Complex a[maxn],b[maxn];
int total=1,digit=0;
int rev[2*maxn];
void fft(Complex *A,int f)
{
for(int i=0;i<total;++i)
if(i<rev[i])
swap(A[i],A[rev[i]]);
for(int mid=1;mid<total;mid<<=1)
{
Complex Wn={cos(PI/mid),f*sin(PI/mid)};
int len=mid<<1;
for(int p=0;p<total;p+=len)
{
Complex Wk={1,0};
for(int k=0;k<mid;++k)
{
Complex x=A[p+k],y=Wk*A[p+k+mid];
A[p+k]=x+y;
A[p+k+mid]=x-y;
Wk=Wk*Wn;
}
}
}
}
int main()
{
int n,m;
scanf("%d%d",&n,&m);
rep(i,0,n)
scanf("%lf",&a[i].x);
rep(i,0,m)
scanf("%lf",&b[i].x);
while(total<=n+m)
total<<=1,digit++;
for(int i=0;i<total;++i)
rev[i]=(rev[i>>1]>>1)|(i&1)<<(digit-1);
fft(a,1);
fft(b,1);
for(int i=0;i<=total;++i)
a[i]=a[i]*b[i];
fft(a,-1);
for(int i=0;i<=n+m;++i)
printf("%d ",(int)(a[i].x/total+0.5));
return 0;
}
模板写法
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <queue>
#include <vector>
#include <set>
#include <map>
#include <cstring>
#include <string>
#include <cmath>
#define rep(i,a,b) for (int i=a; i<=b; ++i)
#define per(i,b,a) for (int i=b; i>=a; --i)
#define mes(a,b) memset(a,b,sizeof(a))
#define mp make_pair
#define ll long long
#define pb push_back
#define pii pair<int,int>
#define pll pair<ll,ll>
#define ls (rt<<1)
#define rs ((rt<<1)|1)
#define isZero(d) (abs(d) < 1e-8)
using namespace std;
const int maxn=1e6+5,INF=0x3f3f3f3f;
const int mod=1e9+7;
const double PI=acos(-1.0);
struct Complex
{
double x,y;
Complex(double x1=0.0,double y1=0.0)
{
x=x1,y=y1;
}
};
Complex operator+(Complex a,Complex b)
{
return {a.x+b.x,a.y+b.y};
}
Complex operator-(Complex a,Complex b)
{
return {a.x-b.x,a.y-b.y};
}
Complex operator*(Complex a,Complex b)
{
return {a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x};
}
struct FFT
{
int total,digit,rev[maxn<<2];
Complex a[maxn<<2],b[maxn<<2];
void init(int len)
{
total=1,digit=0;
while(total<=len)
total<<=1,digit++;
for(int i=0;i<total;++i)
rev[i]=(rev[i>>1]>>1)|(i&1)<<(digit-1);
}
void fft(Complex *A,int f)
{
for(int i=0;i<total;++i)
if(i<rev[i])
swap(A[i],A[rev[i]]);
for(int mid=1;mid<total;mid<<=1)
{
Complex Wn={cos(PI/mid),f*sin(PI/mid)};
int len=mid*2;
for(int p=0;p<total;p+=len)
{
Complex Wk={1,0};
for(int k=0;k<mid;++k)
{
Complex x=A[p+k],y=Wk*A[p+k+mid];
A[p+k]=x+y;
A[p+k+mid]=x-y;
Wk=Wk*Wn;
}
}
}
}
void cal()
{
fft(a,1),fft(b,1);
for(int i=0;i<total;++i)
a[i]=a[i]*b[i];
fft(a,-1);
}
}F;
int main()
{
int n,m;
scanf("%d%d",&n,&m);
for(int i=0;i<=n;++i)
scanf("%lf",&F.a[i].x);
for(int i=0;i<=m;++i)
scanf("%lf",&F.b[i].x);
F.init(n+m);
F.cal();
for(int i=0;i<=n+m;++i)
printf("%d ",(int)(F.a[i].x/F.total+0.5));
return 0;
}
学习博客
https://www.cnblogs.com/RabbitHu/p/FFT.html
https://www.cnblogs.com/RabbitHu/p/FFT.html