什么是FFT?
快速傅里叶变换,OI中主要用于多项式乘法。时间复杂度为O(nlogn)
下面是一些前置知识。
单位复根
n次单位复根是满足 wn w n =1的复数 w w 。
称为主n次单位根,所有的其他n次单位根都是 wn w n 的幂。对于k=0,1,…,n-1,这些根是 e2πik/n e 2 π i k / n 。他们均匀的分布在以复平面的原点为圆心的单位圆上,其中复数幂定义为 eui=cos(u)+sin(u)i e u i = c o s ( u ) + s i n ( u ) i 。
重要结论
1. n个n次单位复根 w0n,w1n,w2n...wn−1n w n 0 , w n 1 , w n 2 . . . w n n − 1 在乘法意义下形成一个群,即 wjn∗wkn=w(j+k)%nn w n j ∗ w n k = w n ( j + k ) % n
2. 对任何整数n≥0,k≥0,d>0,有 wdkdn=wkn w d n d k = w n k
3. 如果n>0为偶数,n个n次单位复根的平方的集合等于n/2个n/2次单位复根的集合。多项式的系数表示法
对于一个次数界为n的多项式 A(x)=∑n−1j=0ajxj A ( x ) = ∑ j = 0 n − 1 a j x j 则 (a0,a1,a2,...,an−1) ( a 0 , a 1 , a 2 , . . . , a n − 1 ) 即是系数表示。点值表示法
对于一个次数界为n的多项式的点值表示法,就是由n个x以及对应的y=A(x)组成的集合:
((x0,y0),(x1,y1),(x2,y2)...(xn−1,yn−1)) ( ( x 0 , y 0 ) , ( x 1 , y 1 ) , ( x 2 , y 2 ) . . . ( x n − 1 , y n − 1 ) )
显然这样表示出的多项式对应的系数表示是唯一的,就相当于n个方程n个未知数求解。
点值表示的好处在于如果两个多项式
xi
x
i
的取值相同,那么多项式乘法就变能O(n)实现,即把对应的y相乘。如果我们能做到在两种表示法之间快速转换就能加速乘法过程了。
我们需要巧妙的选择x的取值来加速求点值的过程,没错就是用单位复根。
现在我们要求一个次数界为n的给定系数表示的多项式在
x=w0n,w1n,w2n...wn−1n
x
=
w
n
0
,
w
n
1
,
w
n
2
.
.
.
w
n
n
−
1
的点值。
FFT实际上就是用了分治的思想,把一个次数界为n的多项式A(x)分成这样两个:
A[0](x)=a0+a2x1+a4x2+a6x3...+an−2xn/2−1
A
[
0
]
(
x
)
=
a
0
+
a
2
x
1
+
a
4
x
2
+
a
6
x
3
.
.
.
+
a
n
−
2
x
n
/
2
−
1
A[1](x)=a1+a3x1+a5x2+a7x3...+an−1xn/2−1
A
[
1
]
(
x
)
=
a
1
+
a
3
x
1
+
a
5
x
2
+
a
7
x
3
.
.
.
+
a
n
−
1
x
n
/
2
−
1
于是有
A(x)=A[0](x2)+xA[1](x2)
A
(
x
)
=
A
[
0
]
(
x
2
)
+
x
A
[
1
]
(
x
2
)
可得:
(k=0,1,2,...,n/2−1)
(
k
=
0
,
1
,
2
,
.
.
.
,
n
/
2
−
1
)
A(wkn)=A[0](w2kn)+wknA[1](w2kn)
A
(
w
n
k
)
=
A
[
0
]
(
w
n
2
k
)
+
w
n
k
A
[
1
]
(
w
n
2
k
)
=A[0](wkn/2)+wknA[1](wkn/2)
=
A
[
0
]
(
w
n
/
2
k
)
+
w
n
k
A
[
1
]
(
w
n
/
2
k
)
A(wk+n/2n)=A[0](w2(k+n/2)n)+wk+n/2nA[1](w2(k+n/2)n)
A
(
w
n
k
+
n
/
2
)
=
A
[
0
]
(
w
n
2
(
k
+
n
/
2
)
)
+
w
n
k
+
n
/
2
A
[
1
]
(
w
n
2
(
k
+
n
/
2
)
)
=A[0](wk+n/2n/2)+wk+n/2nA[1](wk+n/2n/2)
=
A
[
0
]
(
w
n
/
2
k
+
n
/
2
)
+
w
n
k
+
n
/
2
A
[
1
]
(
w
n
/
2
k
+
n
/
2
)
=A[0](wkn/2)−wknA[1](wkn/2)
=
A
[
0
]
(
w
n
/
2
k
)
−
w
n
k
A
[
1
]
(
w
n
/
2
k
)
也就是说只要求出
A[0]
A
[
0
]
和
A[1]
A
[
1
]
的
x=w0n,w1n,w2n...wn/2−1n
x
=
w
n
0
,
w
n
1
,
w
n
2
.
.
.
w
n
n
/
2
−
1
的点值即可。这是个和原问题相同规模的减小一半的子问题,递归求解即可。
由于递归实现常数较大,一般都是非递归实现的。非递归需要知道当递归到最后一层时原本的系数
ai
a
i
各自所处的位置。
考虑到每次把系数分成两半时是按奇偶来分的,可以看做依次考察其下标二进制的低位到高位,1往右走,0往左走。最后的位置即是递归到最后一层所处的位置。有一种巧妙的方法就是把0~n-1全部二进制翻转后排序,这样得到的顺序一定是最后的位置。非常妙的想法。
现在已经解决了从系数表示法转到点值表示法的过程,那么逆向操作如何实现呢?直接贴结论好了
aj=1/n∑n−1k=0ykw−kjn
a
j
=
1
/
n
∑
k
=
0
n
−
1
y
k
w
n
−
k
j
。(懒得写了,打公式太烦)可以发现这时只需要稍微改一下之前的算法即可。就是把
wn
w
n
改为
w−1n
w
n
−
1
,最后算出的答案全部除以n。
模板如下:(HDU 1402 高精乘)
#include<cstdio>
#include<cmath>
#include<cstring>
#include<algorithm>
using namespace std;
const double PI=acos(-1.0);
const int maxn=(1<<18)+5;
struct E{
double real,imag;
E(double real=0,double imag=0):real(real),imag(imag){}
void operator /=(double x){ real/=x; imag/=x; }
};
E operator + (E &a,E &b){ return E(a.real+b.real,a.imag+b.imag); }
E operator - (E &a,E &b){ return E(a.real-b.real,a.imag-b.imag); }
E operator * (E &a,E &b){ return E(a.real*b.real-a.imag*b.imag,a.imag*b.real+a.real*b.imag); }
int rev[maxn];
void build_rev(int n){
rev[0]=0; int log2n=log2(n);
for(int i=1;i<=n-1;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(log2n-1));
}
const double eps=1e-10;
int dcmp(double x){
if(fabs(x)<eps) return 0;
return x>0?1:-1;
}
int b[maxn];
struct Complex_Line{
int n; E a[maxn];
void mem(){ for(int i=0;i<=n-1;i++) a[i].real=a[i].imag=0; n=0; }
void FFT(int k){
build_rev(n); for (int i=0;i<=n-1;i++) if (i<rev[i]) swap(a[i],a[rev[i]]);
for(int m=2;m<=n;m<<=1){
E wm(cos(2*PI/m),k*sin(2*PI/m));
for(int k=0;k<n-1;k+=m){
E w(1,0),t0,t1;
for(int j=0;j<=m/2-1;j++,w=w*wm) t0=a[k+j], t1=w*a[k+j+m/2], a[k+j]=t0+t1, a[k+j+m/2]=t0-t1;
}
}
if(k==-1) for(int i=0;i<=n;i++) a[i]/=n;
}
void write(){
memset(b,0,sizeof b);
for(int i=0;i<=n-1;i++) b[i]+=(int)(a[i].real+0.5), b[i+1]+=b[i]/10, b[i]%=10;
int m=n; while(m>1&&!b[m-1]) m--;
for(int i=m-1;i>=0;i--) printf("%d",b[i]);
}
} A,B,C;
char s[maxn];
int main(){
freopen("hdoj1402.in","r",stdin);
freopen("hdoj1402.out","w",stdout);
for(scanf("%s",s);s[0]!='\000';s[0]='\000',scanf("%s",s)){
A.mem(); A.n=strlen(s); for(int i=0;i<=A.n-1;i++) A.a[i].real=s[A.n-1-i]-'0';
scanf("%s",s);
B.mem(); B.n=strlen(s); for(int i=0;i<=B.n-1;i++) B.a[i].real=s[B.n-1-i]-'0';
C.n=1; while(C.n<A.n+B.n) C.n<<=1;
A.n=C.n; A.FFT(1); B.n=C.n; B.FFT(1);
for(int i=0;i<=C.n-1;i++) C.a[i]=A.a[i]*B.a[i];
C.FFT(-1); C.write(); printf("\n");
}
return 0;
}