其实很早就想学习一下 FFT 了,不过比赛的时候和 FFT 有关的题目我都交给了队友,所以也没什么动力学- -今天看到 Gilbert Strang 的线性代数书中竟然也介绍了 FFT,就顺便学习了一下。
FFT 解决的问题
我们知道,一个 $n-1$ 次多项式可以这样表示:$$\sum_{k=0}^{n-1}a_kx^k$$ 这种表示方法称为多项式的“系数表示法”。容易看出,只要确定了 $a_0,a_1,\dots,a_{n-1}$ 的值,就能唯一确定一个 $n-1$ 次多项式。
事实上,一个 $n-1$ 次多项式还可以用“点值表示法”进行表示。我们把多项式看成函数 $f(x)$,只要给出 $n$ 个点 $(x_1,f(x_1)),(x_2,f(x_2)),\dots,(x_n,f(x_n))$,且这 $n$ 个点的 x 值互不相同,我们也能唯一确定一个 $n-1$ 次多项式。确定这个多项式的过程称为“插值”。
为什么 $n$ 个点可以唯一确定多项式呢?我们可以把给出的 $n$ 个点看作一个方程组,用矩阵表示为 $$\begin{bmatrix}1 & x_1 & x_1^2 & \dots & x_1^n \\ 1 & x_2 & x_2^2 & \dots & x_2^n \\ \vdots & \vdots & \vdots & \vdots & \vdots \\ 1 & x_n & x_n^2 & \dots & x_n^n\end{bmatrix} \begin{bmatrix}a_0 \\ a_1 \\ \vdots \\ a_{n-1}\end{bmatrix} = \begin{bmatrix}f(x_1) \\ f(x_2) \\ \vdots \\ f(x_n)\end{bmatrix}$$ 简记为 $Xa = f$,根据线性代数的知识我们知道,矩阵 $X$ 是范德蒙矩阵,其行列式为 $$\prod_{i\ne j}(x_i-x_j)$$ 显然,只要 x 的值各不相等,该矩阵的行列式就不为 0,说明该矩阵可逆,则我们能唯一确定多项式的系数向量为 $a = X^{-1}f$。
我们需要解决的问题是:给出一个多项式的系数表示或点值表示,如何用一种表示方法推出另一种表示方法。
如果我们使用朴素的方法,在给出系数表示时,我们随便取 $n$ 个不同的 x 值,分别计算出多项式的取值来获得点值表示,显然复杂度为 $O(n^2)$;在给出点值表示时,我们使用拉格朗日插值法获得系数表示,复杂度也为 $O(n^2)$。有没有更快的解决方法呢?
FFT 就能够更快地帮我们在多项式的系数表示和点值表示之间进行转换,它的复杂度是 $O(n\text{log}n)$。接下来我们就来推导并介绍 FFT 是如何加快这个过程的。
单位根
为了推导 FFT,我们首先介绍单位根。
我们称 $x^n=1$ 的所有复数解为单位根,这 $n$ 个单位根是 $e^{2\pi i / n} = (e^{2\pi / n})^i = w_n^i\quad i \in \{0,1,\dots,n-1\}$
来验证一下,根据欧拉公式 $e^{\theta i} = \text{cos}\theta + \text{sin}\theta$,我们有 $$(e^{2\pi i/n})^n = e^{2\pi i} = \text{cos}2\pi + \text{sin}2 \pi = 1$$
根据代数基本定理,这个方程有且只有 $n$ 个复数解,这些解就是全部的单位根了。
单位根有一个很奇妙的性质:所有单位根之和为 0。证明如下:$$(1 + w_n + w_n^2 + \dots + w_n^{n-1})(w_n-1) = w_n^n-1 = 1-1 = 0$$ 显然我们有 $$1 + w_n + w_n^2 + \dots + w_n^{n-1} = 0$$ 或 $$w_n = 1$$ 只有在 $n = 1$ 时,才有 $w_n = 1$。所以,对于 $n \ne 1$,单位根之和为 0。
单位根还有另一个(比较容易观察到的)性质:$w_n^2 = e^{4\pi i / n} = e^{2\pi i / (n/2)} = w_{n/2}$,这个性质是 FFT 的核心。
傅里叶矩阵
定义 $n$ 阶傅里叶矩阵如下 $$F_n = \begin{bmatrix}1 & 1 & 1 & \dots & 1 \\ 1 & w_n^{1\times 1} & w_n^{1\times 2} & \dots & w_n^{1 \times (n-1)}\\ 1 & w_n^{2\times 1} & w_n^{2\times 2} & \dots & w_n^{2\times (n-1)}\\ \vdots & \vdots & \vdots & \vdots & \vdots\\ 1 & w_n^{(n-1)\times 1} & w_n^{(n-1)\times 2} & \dots & w_n^{(n-1)\times (n-1)}\end{bmatrix}$$ 这个矩阵也有一个非常好的性质:它的逆矩阵非常容易运算:$$F_n^{-1} = \frac{1}{n}\begin{bmatrix}1 & 1 & 1 & \dots & 1 \\ 1 & w_n^{-1\times 1} & w_n^{-1\times 2} & \dots & w_n^{-1 \times (n-1)}\\ 1 & w_n^{-2\times 1} & w_n^{-2\times 2} & \dots & w_n^{-2\times (n-1)}\\ \vdots & \vdots & \vdots & \vdots & \vdots\\ 1 & w_n^{-(n-1)\times 1} & w_n^{-(n-1)\times 2} & \dots & w_n^{-(n-1)\times (n-1)}\end{bmatrix}$$ 我们来验证一下:
$F_n$ 的第 $a$ 行与 $F_n^{-1}$ 的第 $b$ 列的点积为 $$\frac{1}{n}\sum_{k=0}^{n-1}w_n^{a\times k}w_n^{-k\times b} = \frac{1}{n}\sum_{k=0}^{n-1}w_n^{k(a-b)}$$ 若 $a = b$,显然点积为 1;若 $a \ne b$,我们记 $W = w_n^{a-b}$类似于上一节的推导,我们有 $$\frac{1}{n}(W-1)\sum_{k=0}^{n-1}W^k = W^n-1 = e^{2\pi(a-b)}-1 = 0$$ 而由 $a \ne b$ 我们有 $W \ne 1$,则点积为 0。
快速离散傅里叶变换
给出一个 $n$ 维向量 $c$,如果我们用朴素的方法计算 $y = F_nc$,需要进行一个复杂度为 $O(n^2)$ 的乘法。而 FFT 可以让这个过程加快至 $O(n\text{log}n)$。
不失一般性地,我们假设 $n = 2^l$(如果 $n$ 不是 2 的幂,我们可以通过补 0 把多项式的系数个数补充到 2 的幂)。FFT 的过程如下。
设 $c = \begin{bmatrix} c_0 & c_1 & c_2 & \dots & c_{n-1} \end{bmatrix}^T$,我们构造两个 $\frac{n}{2}$ 维的向量 $c' = \begin{bmatrix} c_0 & c_2 & c_4 & \dots & c_{n-2} \end{bmatrix}^T$ 与 $c'' = \begin{bmatrix} c_1 & c_3 & c_5 & \dots & c_{n-1} \end{bmatrix}^T$。这样,我们就能分别计算出 $y' = F_{n/2}c'$ 与 $y'' = F_{n/2}c''$。假设我们能通过某种方式,从 $y'$ 与 $y''$ 中获得 $y$,我们就发现了一个递归的过程。这样,我们就能一步一步把问题分解成最小规模的 $y = F_1c = c$,再通过“某种方式”一层一层合并出原来的解。
这个“某种过程”的式子如下(设 $y_j$ 表示向量 $y$ 的第 $j$ 项元素):$$y_j = y_j' + w_n^jy_j'' \quad j = 0\dots n-1$$ $$y_{j+n/2} = y_j' - w_n^jy_j'' \quad j = 0\dots n-1$$ 显然,这个合并过程的复杂度是线性的,所以总的复杂度为 $O(n\text{log}n)$。
接下来我们证明这个“递归式”。首先是第一个式子:$$y_j = \sum_{k=0}^{n-1}w_n^{jk}c_k = \sum_{k=0}^{n/2-1}w_n^{j\times 2k}c_{2k} + \sum_{k=0}^{n/2-1}w_n^{j\times (2k+1)}c_{2k+1}$$ $$=\sum_{k=0}^{n/2-1}w_{n/2}^{jk}c_k' + w_n^j\sum_{k=0}^{n/2-1}w_{n/2}^{jk}c_k'' = y_j' + w_n^jy_j''$$ 接下来证明第二个式子: $$y_{j+n/2} = \sum_{k=0}^{n-1}w_n^{(j+n/2)k}c_k = \sum_{k=0}^{n/2-1}w_n^{(j+n/2)\times 2k}c_{2k} + \sum_{k=0}^{n/2-1}w_n^{(j+n/2)\times(2k+1)}c_{2k+1}$$ $$=w_n^{nk}\sum_{k=0}^{n/2-1}w_{n/2}^{jk}c_k' + w_n^{nk}w_n^{n/2}w_n^j\sum_{k=0}^{n/2-1}w_{n/2}^{jk}c_k''$$ 注意到 $w_n^{nk} = 1$,$w_n^{n/2} = e^{\pi i} = -1$,我们有 $$y_{j+n/2}=\sum_{k=0}^{n/2-1}w_{n/2}^{jk}c_k' - w_n^j\sum_{k=0}^{n/2-1}w_{n/2}^{jk}c_k'' = y_j'-w_n^jy_j''$$ 这样我们就能在 $O(n\text{log}n)$ 的时间内,利用这种递归的方式快速计算出 $y=F_nc$ 的值。
多项式表示方法之间的转换
对比一下 FFT 的矩阵表达式 $F_nc=y$ 与多项式插值的矩阵表达式 $Xa = f$,如果我们令 $X = F_n$,$a = c$,$f = y$,就建立起了两者之间的关系。FFT 实际上是通过选择单位根作为插值时的 x 值,再使用快速离散傅里叶变换,就能快速算出 x 值为 $n$ 个单位根时,多项式的 $n$ 个取值,以此将系数表示转化为了点值表示。
而要将点值表示转化回系数表示也非常简单,只要求 $c = F_n^{-1}y$ 即可完成,$F_n^{-1}$ 的形式已经在之前的小节中给出,与 $F_n$ 非常相似,求解方法也与 $F_n$ 相同,这里不再赘述。
FFT 求多项式乘法
如果我们有两个多项式 $A = \displaystyle \sum_{k=0}^{n-1}a_kx^k$ 与 $B = \displaystyle \sum_{k=0}^{m-1}b_kx^k$,我们如何快速求出两者的乘积多项式 $C = \displaystyle \sum_{k=0}^{n+m-2}c_kx^k$(其中 $c_k = \displaystyle \sum_{t=0}^k a_tb_{k-t}$)呢?
我们先通过两个 FFT($y=F_nc$)将多项式 $A$ 与 $B$ 由系数表示转化为点值表示 $(x_k,A(x_k))$ 与 $(x_k, B(x_k))$。由于 $C = AB$,显然 $C$ 的点值表示为 $(x_k,A(x_k)B(x_k))$,那么我们可以用 $O(n)$ 的时间通过 $A$ 与 $B$ 的点值表示获得 $C$ 的点值表示。再通过一个 FFT($c=F_n^{-1}y$)将 $C$ 的点值表示变回系数表示即可。
不难注意到 $c_k$ 的值是由卷积运算得来的,所以 FFT 也可以用于加快卷积运算。
最后附上 FFT 的一个非递归实现(模板题:LOJ108)
1 #include <bits/stdc++.h> 2 #define MAXP 262144 3 #define PI 3.1415926535897 4 using namespace std; 5 6 int n,m; 7 complex<double> A[300010],B[300010],C[300010]; 8 9 void change(complex<double> *a,int len) 10 { 11 int i,j,k; 12 for(i=1,j=len>>1;i<len;i++) 13 { 14 if(i<j) swap(a[i],a[j]); 15 for(k=len>>1;j&k;k>>=1) j &= ~k; 16 j |= k; 17 } 18 } 19 20 void fft(complex<double> *a,int len,bool inv) 21 { 22 int i,j,k,lim; 23 complex<double> w,t,u,v; 24 change(a,len); 25 26 for(i=2;i<=len;i<<=1) 27 { 28 w = complex<double>(cos(2*PI/i),sin(2*PI/i*(inv?-1:1))); 29 for(j=0;j<len;j+=i) 30 { 31 t = 1; lim = j + (i>>1); 32 for(k=j;k<lim;k++) 33 { 34 u = a[k]; v = t * a[k+(i>>1)]; 35 a[k] = u + v; a[k+(i>>1)] = u - v; 36 t *= w; 37 } 38 } 39 } 40 41 if(inv) for(i=0;i<len;i++) a[i] /= len; 42 } 43 44 int main() 45 { 46 int i,x; 47 48 scanf("%d%d",&n,&m); 49 for(i=0;i<=n;i++) scanf("%d",&x), A[i] = x; 50 for(i=0;i<=m;i++) scanf("%d",&x), B[i] = x; 51 52 fft(A,MAXP,false); fft(B,MAXP,false); 53 for(i=0;i<MAXP;i++) C[i] = A[i]*B[i]; 54 fft(C,MAXP,true); 55 56 for(i=0;i<=n+m;i++) printf("%.0f%c",C[i].real(),"\n "[i<n+m]); 57 return 0; 58 }
每个点大约要跑 150ms,感觉常数略大...代码里使用了 STL 自带的 complex,据说速度比较慢,不过我后来自己定义了一个 complex,并没有发现速度快多少,也许有什么更加科学的实现方式吧...