OJ:快速傅里叶变换(FFT)学习笔记

FFT 快速傅里叶变换 学习笔记

前言

FFT是一个快速(O(nlogn))求两个多项式乘积的算法。
在阅读本文之前,请先阅读1.本文是对1的解释和补充,主要贡献有两点:

  • 把原文没说清的理论部分补上;
  • 通过递归,把原文的思路重新理了一遍,使理论和代码都更加好理解。

另外,本文的附录里面只放了源代码,且在文中有提到。

理论:IFFT的理论基础

只要搞懂以下5个topic,就可以明白快速傅里叶变换的代码逻辑:

1.多项式的系数表示法和点值表示法
2.n次单位根
3.DFT
4.FFT
5.IFFT

理论部分的前四点构成了FFT的理论基础。 对于这四点,1都有,且很好理解。请读者阅读原文熟悉前四点理论。

下面主要讲一下第5点(IFFT的理论基础)。

FFT本质上解决了一个乘法问题

以下所有记号和1对应,并记 w i = ω n i = e x p ( i ∗ 2 π i / n ) w_i = \omega_n^i = exp(\mathbf{i}*2\pi i/n) wi=ωni=exp(i2πi/n),对应n次单位根。( i \mathbf{i} i是虚数单位。)

记n次单位根的幂矩阵W,系数向量C,值向量V如下:

W = [ W i , j = ( w i ) j ] n ∗ n = [ ( w 0 ) 0 ⋯ ( w 0 ) n − 1 ( w 1 ) 0 ⋯ ( w 1 ) n − 1 ⋮ ⋱ ⋮ ( w n − 1 ) 0 ⋯ ( w n − 1 ) n − 1 ] W=[W_{i,j}=(w_i)^j]_{n*n}= \left[ \begin{aligned} (w_0)^0 & \cdots & (w_0)^{n-1} \\ (w_1)^0 & \cdots & (w_1)^{n-1} \\ \vdots & \ddots & \vdots \\ (w_{n-1})^0 & \cdots & (w_{n-1})^{n-1} \\ \end{aligned} \right] W=[Wi,j=(wi)j]nn=(w0)0(w1)0(wn1)0(w0)n1(w1)n1(wn1)n1

C = [ a 0 a 1 ⋮ a n − 1 ] C = \left[ \begin{aligned} a_0 \\ a_1 \\ \vdots \\ a_{n-1} \end{aligned} \right] C=a0a1an1

V = [ A ( x 0 ) A ( x 1 ) ⋮ A ( x n − 1 ) ] V = \left[ \begin{aligned} A(x_0) \\ A(x_1) \\ \vdots \\ A(x_{n-1}) \end{aligned} \right] V=A(x0)A(x1)A(xn1)

其中 W C = V WC=V WC=V.

FFT本质上解决的问题是已知W和C,计算V的乘法问题,对这个问题给出了一个 O ( n l o g n ) O(nlogn) O(nlogn)的方法。

IFFT可以被表示成一个类似的乘法问题

IFFT可以表示成一个类似的乘法问题,并通过相似的方式解决。这一部分参考了2.

我们首先求W的逆矩阵。记 v i = c o n j ( w i ) = e x p ( − i ∗ 2 π i / n ) v_i = conj(w_i) = exp(-\mathbf{i}*2\pi i/n) vi=conj(wi)=exp(i2πi/n),即 w i w_i wi的共轭。( i \mathbf{i} i是虚数单位。)

记矩阵W’为

W ′ = [ W i , j ′ = ( v i ) j ] n ∗ n = [ ( v 0 ) 0 ⋯ ( v 0 ) n − 1 ( v 1 ) 0 ⋯ ( v 1 ) n − 1 ⋮ ⋱ ⋮ ( v n − 1 ) 0 ⋯ ( v n − 1 ) n − 1 ] W'=[W'_{i,j}=(v_i)^j]_{n*n}= \left[ \begin{aligned} (v_0)^0 & \cdots & (v_0)^{n-1} \\ (v_1)^0 & \cdots & (v_1)^{n-1} \\ \vdots & \ddots & \vdots \\ (v_{n-1})^0 & \cdots & (v_{n-1})^{n-1} \\ \end{aligned} \right] W=[Wi,j=(vi)j]nn=(v0)0(v1)0(vn1)0(v0)n1(v1)n1(vn1)n1

发现 W W ′ = W ′ W = n I WW'=W'W=nI WW=WW=nI(易证),于是W的逆矩阵是 W ′ n \frac{W'}{n} nW,有 W ′ V = n C W'V=nC WV=nC.因此,只要在FFT的实现算法中把n次单位根改成其共轭,就可以从向量V用 O ( n l o g n ) O(nlogn) O(nlogn)时间求解到向量C,即实现IFFT算法。

综上所述,实现IFFT和实现FFT使用的是同一套逻辑,只是使用了不同的n次单位根。

FFT算法的基础实现

以下是FFT算法的基础实现,fft_logic是FFT和IFFT的共同主逻辑,fftifft是给fft_logic套的两个壳。结合注释阅读fft_logic的代码即可。
(完整的程序见附录1,包括测试用的主函数。complex类的使用见3.)

#include <iostream>
#include <cmath>
#include <complex>
using namespace std;

const double pi = 3.14159265358979323846;
const int N = 8; // 多项式的最大支持位数
complex<double> b[N]; // 用来充当临时调整空间的数组

void fft_logic(complex<double> *a, int n, int inv){

    // 参数:
    // 当inv = 1时,a是系数多项式,n是当前数组长度(2的幂次),函数效果是原地变成点值多项式
    // 当inv = -1时,a是点值多项式,n是当前数组长度(2的幂次),函数效果是原地变成系数多项式,但是所得的系数是n倍,需要在包装的函数中进行调整

    if (n == 1) return; // 为什么?因为omega_1^0=1,点值多项式和系数多项式的表示完全一致。

    // 利用B暂存和写回,把a的顺序调整为 a[0] a[2] .. a[n-2] a[1] a[3] .. a[n-1],前后两半

    for(int i = 0; i < n/2; i ++){
        b[i]       = a[i * 2];
        b[i + n/2] = a[i * 2 + 1];
    }
    for(int i = 0; i < n; i ++)
        a[i] = b[i];

    // 分治求A1和A2

    fft_logic(a, n/2, inv);
    fft_logic(a + n/2, n/2, inv);

    // 通过A1和A2,计算A

    double unit_rad = 2 * pi / n; // 单位角幅度值

    for(int i = 0; i < n/2; i ++){
        complex<double> x(cos(i * unit_rad), inv*sin(i * unit_rad)); // x = omega_n^i 
        complex<double> tmp1 = a[i];
        complex<double> tmp2 = x * a[i + n/2];
        a[i]       = tmp1 + tmp2;
        a[i + n/2] = tmp1 - tmp2;
    }

}

void fft(complex<double> *a, int n){
    // 输入系数多项式及其长度,原地转换为点值多项式
    fft_logic(a, n, 1);
}

void ifft(complex<double> *a, int n){
    // 输入点值多项式及其长度,原地转换为系数多项式
    fft_logic(a, n, -1);
    for(int i = 0; i < n; i ++) 
        a[i] /= n;
}

如同原文作者所说,这种方法的缺点就是常数太大需要优化…

FFT的迭代版本:优化常数

基于递归的版本

针对上一节中算法的不足,注意到ffp_logic函数中“每次递归都重排”,我们将其优化成“预先重排后直接按顺序处理”。

优化的思路是,针对原来算法中每一次A1和A2的穿插,其实可以通过提前重排原始数组来代替。拿来原文的一张图来说明:

n=8时,原来的顺序最终被穿插成了0 4 2 6 1 5 3 7,n=4时则是0 2 1 3.这个数列仅由n决定,所以我们可以用一个函数生成这样的穿插数列。

void fft_rearrange_decidesequence_logic(int *rev, int n){

    // 给定数组rev和数组长度n,函数的功能是将所需的顺序写入数组,比如n = 4时将顺序 0 2 1 3 写入数组rev

    if(n == 1){
        rev[0] = 0;
        return;
    }

    // 获得 n/2 时的顺序,暂时放在rev的后半

    fft_rearrange_decidesequence_logic(rev + n/2, n/2);

    // 利用 n/2 时的顺序构造 n 时的顺序

    for(int i = 0; i < n/2; i ++){
        rev[i] = 2 * rev[i + n/2];
        rev[i + n/2] = 2 * rev[i + n/2] + 1;
    }

}

fft_rearrange_decidesequence_logic函数用递归的方式,自然地把穿插的逻辑描述出来,结果得到正确的重排顺序。

void fft_rearrange_logic(complex<double> *a, int n){

    // 按照算法,把a重新排列为可以直接递归向上的顺序

    // 计算bit: 满足pow(2, bit) = n

    int bit = 0;
    while((1 << bit) < n) bit ++;
    
    // rev: 确定a最终rearrange的位置序列

    int* rev = new int[n];
    fft_rearrange_decidesequence_logic(rev, n);

    // 按照rev序列调整a的顺序

    complex<double>* tmp_a = new complex<double>[n];

    for(int i = 0; i < n; i ++) tmp_a[i] = a[rev[i]];
    for(int i = 0; i < n; i ++) a[i] = tmp_a[i];

    delete[] rev;
}

fft_rearrange_logic根据n值获得重排序列,并将输入进行重排。

void fft_merge_logic(complex<double> *a, int n, int inv){

    if(n == 1) return;

    fft_merge_logic(a,       n/2, inv);
    fft_merge_logic(a + n/2, n/2, inv);

    double unit_rad = 2 * pi / n;

    for(int i = 0; i < n/2; i ++){
        complex<double> x(cos(i * unit_rad), inv*sin(i * unit_rad)); // x = omega_n^i
        complex<double> tmp1 = a[i]; 
        complex<double> tmp2 = x * a[i + n/2];
        a[i]       = tmp1 + tmp2;
        a[i + n/2] = tmp1 - tmp2;
    }

}

void fft_logic(complex<double> *a, int n, int inv){

    // 参数:
    // 当inv = 1时,a是系数多项式,n是当前数组长度(2的幂次),函数效果是原地变成点值多项式
    // 当inv = -1时,a是点值多项式,n是当前数组长度(2的幂次),函数效果是原地变成系数多项式,但是所得的系数是n倍,需要在包装的函数中进行调整

    // a中元素的顺序调整为算法要求的顺序

    fft_rearrange_logic(a, n);
    
    // 调整顺序之后的a进行合并

    fft_merge_logic(a, n, inv);
}

在将重排环节统一之后,fft_logic函数如上所示。在这里,fft_merge_logic把上一节中重排后合并的逻辑抽了出来。

完整的,可运行的代码见附录2.

这一节的目的是对于原文改良速度的FFT算法写出一个易于理解,可读性强的实现。

rearrange逻辑:递归变非递归

比较一下附录2和3的代码,发现fft_rearrange_decidesequence_logic被砍掉了,换成了一个简洁的循环。
(另外,按照rev序列调整a的顺序的部分,尊重原文1修改了一下。我觉得两种方式效率差异不大,而且也都比较容易理解。)

因为有一个结论:每个位置分治后的最终位置为其二进制翻转后得到的位置(即老生常谈的蝴蝶算法),所以可以用这个公式直接生成rev,而不需要写递归。

void fft_rearrange_logic(complex<double> *a, int n){

    // 按照算法,把a重新排列为可以直接递归向上的顺序

    // 计算bit: 满足pow(2, bit) = n

    int bit = 0;
    while((1 << bit) < n) bit ++;
    
    // rev: 确定a最终rearrange的位置序列

    int* rev = new int[n];
    for(int i = 0; i < n; i ++){
        rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
    }

    // 按照rev序列调整a的顺序

    for(int i = 0; i < n; i ++){
        if (i<rev[i])swap(a[i],a[rev[i]]);//不加这条if会交换两次(就是没交换)
    }

    delete[] rev;
}

完整的代码如附录3所示。

去除算法中的所有递归

我们考虑递归的结构(原谅我懒得再画一张新图):

8依赖2个4的完成,每个4依赖其下面2个2的完成,每个2依赖下面2个1的完成。因此,可以通过一个循环,先完成8个1(不需要操作),再完成4个2,再完成2个4,最后完成8.

把这个逻辑用循环写出来,就得到了刨除递归的算法。这个算法和原文1的最终板子是对应的。

void fft_logic(complex<double> *a,int n,int inv){

    // bit: pow(2, bit) = n

    int bit = 0;
    while((1 << bit) < n) bit ++;
    
    // rev: 确定最终rearrange的位置序列

    int* rev = new int[n];
    for(int i = 0; i < n; i ++){
        rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
    }

    // 按照rev序列调整a的顺序

    for(int i = 0; i < n; i ++){
        if (i<rev[i])swap(a[i],a[rev[i]]);//不加这条if会交换两次(就是没交换)
    }

    delete[] rev;
    
    // 调整顺序之后的a进行合并

    for (int mid=1;mid<n;mid*=2){
        // mid循环内,准备把两个长度为mid的序列合并成2 * mid的序列

        double unit_rad = 2 * pi / (2 * mid); // 单位角幅度值

        for (int i=0;i<n;i+=mid*2){
            // i循环内,把位置在i~i+mid*2位置的两个长度为mid的序列合并成2 * mid的序列

            for (int j = 0; j < mid; j ++){
                // j循环内的逻辑和
                complex<double> x(cos(j * unit_rad), inv*sin(j * unit_rad)); // x = omega_n^i
                complex<double> tmp1 = a[i+j]; 
                complex<double> tmp2 = x * a[i+j+mid];
                a[i+j]     = tmp1 + tmp2;
                a[i+j+mid] = tmp1 - tmp2;
            }
        }
    }
}

完整的代码如附录4所示。

附录

附录1

#include <iostream>
#include <cmath>
#include <complex>
using namespace std;

const double pi = 3.14159265358979323846;
const int N = 8; // 多项式的最大支持位数
complex<double> b[N]; // 用来充当临时调整空间的数组

void fft_logic(complex<double> *a, int n, int inv){

    // 参数:
    // 当inv = 1时,a是系数多项式,n是当前数组长度(2的幂次),函数效果是原地变成点值多项式
    // 当inv = -1时,a是点值多项式,n是当前数组长度(2的幂次),函数效果是原地变成系数多项式,但是所得的系数是n倍,需要在包装的函数中进行调整

    if (n == 1) return; // 为什么?因为omega_1^0=1,点值多项式和系数多项式的表示完全一致。

    // 利用B暂存和写回,把a的顺序调整为 a[0] a[2] .. a[n-2] a[1] a[3] .. a[n-1],前后两半

    for(int i = 0; i < n/2; i ++){
        b[i]       = a[i * 2];
        b[i + n/2] = a[i * 2 + 1];
    }
    for(int i = 0; i < n; i ++)
        a[i] = b[i];

    // 分治求A1和A2

    fft_logic(a, n/2, inv);
    fft_logic(a + n/2, n/2, inv);

    // 通过A1和A2,计算A

    double unit_rad = 2 * pi / n; // 单位角幅度值

    for(int i = 0; i < n/2; i ++){
        complex<double> x(cos(i * unit_rad), inv*sin(i * unit_rad)); // x = omega_n^i 
        complex<double> tmp1 = a[i];
        complex<double> tmp2 = x * a[i + n/2];
        a[i]       = tmp1 + tmp2;
        a[i + n/2] = tmp1 - tmp2;
    }

}

void fft(complex<double> *a, int n){
    // 输入系数多项式及其长度,原地转换为点值多项式
    fft_logic(a, n, 1);
}

void ifft(complex<double> *a, int n){
    // 输入点值多项式及其长度,原地转换为系数多项式
    fft_logic(a, n, -1);
    for(int i = 0; i < n; i ++) 
        a[i] /= n;
}

// 主函数测试

complex<double> A1[N], A2[N]; // 相乘的多项式
complex<double> C[N]; // 相乘结果

int main(){

    // 两个相乘多项式的初始化。注意:两个相乘多项式和最后的乘积多项式需要有相同的项数

    for(int i = 0; i < N; i ++){
        A1[i] = A2[i] = 0;
    }

    A1[0] = 1, A1[1] = 1, A1[2] = 3, A1[3] = 2; // A1(x) = 2x3 + 3x2 +  x + 1
    A2[0] = 4, A2[1] = 3, A2[2] = 5, A2[3] = 1; // A2(x) =  x3 + 5x2 + 3x + 4

    // F,乘,F逆

    fft(A1, N);
    fft(A2, N);

    /*

    for(int i = 0; i < N; i ++) 
        cout << A1[i] << " ";
    cout << endl;

    for(int i = 0; i < N; i ++) 
        cout << A2[i] << " ";
    cout << endl;

    */

    for(int i = 0; i < N; i ++) 
        C[i] = A1[i] * A2[i];

    ifft(C, N);

    // 输出

    

    for(int i = 0; i < N; i ++) 
        cout << C[i] << endl;

    

    return 0;
}

附录2

#include <iostream>
#include <cmath>
#include <complex>
using namespace std;

const double pi = 3.14159265358979323846;
const int N = 8; // 多项式的最大支持位数
complex<double> b[N]; // 用来充当临时调整空间的数组

void fft_merge_logic(complex<double> *a, int n, int inv){

    if(n == 1) return;

    fft_merge_logic(a,       n/2, inv);
    fft_merge_logic(a + n/2, n/2, inv);

    double unit_rad = 2 * pi / n;

    for(int i = 0; i < n/2; i ++){
        complex<double> x(cos(i * unit_rad), inv*sin(i * unit_rad)); // x = omega_n^i
        complex<double> tmp1 = a[i]; 
        complex<double> tmp2 = x * a[i + n/2];
        a[i]       = tmp1 + tmp2;
        a[i + n/2] = tmp1 - tmp2;
    }

}

void fft_rearrange_decidesequence_logic(int *rev, int n){

    // 给定数组rev和数组长度n,函数的功能是将所需的顺序写入数组,比如n = 4时将顺序 0 2 1 3 写入数组rev

    if(n == 1){
        rev[0] = 0;
        return;
    }

    // 获得 n/2 时的顺序,暂时放在rev的后半

    fft_rearrange_decidesequence_logic(rev + n/2, n/2);

    // 利用 n/2 时的顺序构造 n 时的顺序

    for(int i = 0; i < n/2; i ++){
        rev[i] = 2 * rev[i + n/2];
        rev[i + n/2] = 2 * rev[i + n/2] + 1;
    }

}

void fft_rearrange_logic(complex<double> *a, int n){

    // 按照算法,把a重新排列为可以直接递归向上的顺序

    // 计算bit: 满足pow(2, bit) = n

    int bit = 0;
    while((1 << bit) < n) bit ++;
    
    // rev: 确定a最终rearrange的位置序列

    int* rev = new int[n];
    fft_rearrange_decidesequence_logic(rev, n);

    // 按照rev序列调整a的顺序

    complex<double>* tmp_a = new complex<double>[n];

    for(int i = 0; i < n; i ++) tmp_a[i] = a[rev[i]];
    for(int i = 0; i < n; i ++) a[i] = tmp_a[i];

    delete[] rev;
}

void fft_logic(complex<double> *a, int n, int inv){

    // 参数:
    // 当inv = 1时,a是系数多项式,n是当前数组长度(2的幂次),函数效果是原地变成点值多项式
    // 当inv = -1时,a是点值多项式,n是当前数组长度(2的幂次),函数效果是原地变成系数多项式,但是所得的系数是n倍,需要在包装的函数中进行调整

    // a中元素的顺序调整为算法要求的顺序

    fft_rearrange_logic(a, n);
    
    // 调整顺序之后的a进行合并

    fft_merge_logic(a, n, inv);
}

void fft(complex<double> *a, int n){
    // 输入系数多项式及其长度,原地转换为点值多项式
    fft_logic(a, n, 1);
}

void ifft(complex<double> *a, int n){
    // 输入点值多项式及其长度,原地转换为系数多项式
    fft_logic(a, n, -1);
    for(int i = 0; i < n; i ++) 
        a[i] /= n;
}

// 主函数测试

complex<double> A1[N], A2[N]; // 相乘的多项式
complex<double> C[N]; // 相乘结果

int main(){

    // 两个相乘多项式的初始化。注意:两个相乘多项式和最后的乘积多项式需要有相同的项数

    for(int i = 0; i < N; i ++){
        A1[i] = A2[i] = 0;
    }

    A1[0] = 1, A1[1] = 1, A1[2] = 3, A1[3] = 2; // A1(x) = 2x3 + 3x2 +  x + 1
    A2[0] = 4, A2[1] = 3, A2[2] = 5, A2[3] = 1; // A2(x) =  x3 + 5x2 + 3x + 4

    // F,乘,F逆

    fft(A1, N);
    fft(A2, N);

    /*

    for(int i = 0; i < N; i ++) 
        cout << A1[i] << " ";
    cout << endl;

    for(int i = 0; i < N; i ++) 
        cout << A2[i] << " ";
    cout << endl;

    */

    for(int i = 0; i < N; i ++) 
        C[i] = A1[i] * A2[i];

    ifft(C, N);

    // 输出

    for(int i = 0; i < N; i ++) 
        cout << C[i] << endl;

    return 0;
}

附录3

#include <iostream>
#include <cmath>
#include <complex>
using namespace std;

const double pi = 3.14159265358979323846;
const int N = 1024; // 多项式的最大支持位数
complex<double> b[N]; // 用来充当临时调整空间的数组

void fft_merge_logic(complex<double> *a, int n, int inv){

    if(n == 1) return;

    fft_merge_logic(a,       n/2, inv);
    fft_merge_logic(a + n/2, n/2, inv);

    double unit_rad = 2 * pi / n;

    for(int i = 0; i < n/2; i ++){
        complex<double> x(cos(i * unit_rad), inv*sin(i * unit_rad)); // x = omega_n^i
        complex<double> tmp1 = a[i]; 
        complex<double> tmp2 = x * a[i + n/2];
        a[i]       = tmp1 + tmp2;
        a[i + n/2] = tmp1 - tmp2;
    }

}

void fft_rearrange_logic(complex<double> *a, int n){

    // 按照算法,把a重新排列为可以直接递归向上的顺序

    // 计算bit: 满足pow(2, bit) = n

    int bit = 0;
    while((1 << bit) < n) bit ++;
    
    // rev: 确定a最终rearrange的位置序列

    int* rev = new int[n];
    for(int i = 0; i < n; i ++){
        rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
    }

    // 按照rev序列调整a的顺序

    for(int i = 0; i < n; i ++){
        if (i<rev[i])swap(a[i],a[rev[i]]);//不加这条if会交换两次(就是没交换)
    }

    delete[] rev;
}

void fft_logic(complex<double> *a, int n, int inv){

    // 参数:
    // 当inv = 1时,a是系数多项式,n是当前数组长度(2的幂次),函数效果是原地变成点值多项式
    // 当inv = -1时,a是点值多项式,n是当前数组长度(2的幂次),函数效果是原地变成系数多项式,但是所得的系数是n倍,需要在包装的函数中进行调整

    // a中元素的顺序调整为算法要求的顺序

    fft_rearrange_logic(a, n);
    
    // 调整顺序之后的a进行合并

    fft_merge_logic(a, n, inv);
}

void fft(complex<double> *a, int n){
    // 输入系数多项式及其长度,原地转换为点值多项式
    fft_logic(a, n, 1);
}

void ifft(complex<double> *a, int n){
    // 输入点值多项式及其长度,原地转换为系数多项式
    fft_logic(a, n, -1);
    for(int i = 0; i < n; i ++) 
        a[i] /= n;
}

// 主函数测试

complex<double> A1[N], A2[N]; // 相乘的多项式
complex<double> C[N]; // 相乘结果

int main(){

    // 两个相乘多项式的初始化。注意:两个相乘多项式和最后的乘积多项式需要有相同的项数

    for(int i = 0; i < N; i ++){
        A1[i] = A2[i] = 0;
    }

    A1[0] = 1, A1[1] = 1, A1[2] = 3, A1[3] = 2; // A1(x) = 2x3 + 3x2 +  x + 1
    A2[0] = 4, A2[1] = 3, A2[2] = 5, A2[3] = 1; // A2(x) =  x3 + 5x2 + 3x + 4

    // F,乘,F逆

    fft(A1, N);
    fft(A2, N);

    /*

    for(int i = 0; i < N; i ++) 
        cout << A1[i] << " ";
    cout << endl;

    for(int i = 0; i < N; i ++) 
        cout << A2[i] << " ";
    cout << endl;

    */

    for(int i = 0; i < N; i ++) 
        C[i] = A1[i] * A2[i];

    ifft(C, N);

    // 输出

    for(int i = 0; i < N; i ++) 
        cout << C[i] << endl;

    return 0;
}

附录4

#include <iostream>
#include <cmath>
#include <complex>
using namespace std;

const double pi = 3.14159265358979323846;
const int N = 8; // 多项式的最大支持位数
complex<double> b[N]; // 用来充当临时调整空间的数组

void fft_logic(complex<double> *a,int n,int inv){

    // bit: pow(2, bit) = n

    int bit = 0;
    while((1 << bit) < n) bit ++;
    
    // rev: 确定最终rearrange的位置序列

    int* rev = new int[n];
    for(int i = 0; i < n; i ++){
        rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
    }

    // 按照rev序列调整a的顺序

    for(int i = 0; i < n; i ++){
        if (i<rev[i])swap(a[i],a[rev[i]]);//不加这条if会交换两次(就是没交换)
    }

    delete[] rev;
    
    // 调整顺序之后的a进行合并

    for (int mid=1;mid<n;mid*=2){
        // mid循环内,准备把两个长度为mid的序列合并成2 * mid的序列

        double unit_rad = 2 * pi / (2 * mid); // 单位角幅度值

        for (int i=0;i<n;i+=mid*2){
            // i循环内,把位置在i~i+mid*2位置的两个长度为mid的序列合并成2 * mid的序列

            for (int j = 0; j < mid; j ++){
                // j循环内的逻辑和
                complex<double> x(cos(j * unit_rad), inv*sin(j * unit_rad)); // x = omega_n^i
                complex<double> tmp1 = a[i+j]; 
                complex<double> tmp2 = x * a[i+j+mid];
                a[i+j]     = tmp1 + tmp2;
                a[i+j+mid] = tmp1 - tmp2;
            }
        }
    }
}

void fft(complex<double> *a, int n){
    // 输入系数多项式及其长度,原地转换为点值多项式
    fft_logic(a, n, 1);
}

void ifft(complex<double> *a, int n){
    // 输入点值多项式及其长度,原地转换为系数多项式
    fft_logic(a, n, -1);
    for(int i = 0; i < n; i ++) 
        a[i] /= n;
}

// 主函数测试

complex<double> A1[N], A2[N]; // 相乘的多项式
complex<double> C[N]; // 相乘结果

int main(){

    // 两个相乘多项式的初始化。注意:两个相乘多项式和最后的乘积多项式需要有相同的项数

    for(int i = 0; i < N; i ++){
        A1[i] = A2[i] = 0;
    }

    A1[0] = 1, A1[1] = 1, A1[2] = 3, A1[3] = 2; // A1(x) = 2x3 + 3x2 +  x + 1
    A2[0] = 4, A2[1] = 3, A2[2] = 5, A2[3] = 1; // A2(x) =  x3 + 5x2 + 3x + 4

    // F,乘,F逆

    fft(A1, N);
    fft(A2, N);

    /*

    for(int i = 0; i < N; i ++) 
        cout << A1[i] << " ";
    cout << endl;

    for(int i = 0; i < N; i ++) 
        cout << A2[i] << " ";
    cout << endl;

    */

    for(int i = 0; i < N; i ++) 
        C[i] = A1[i] * A2[i];

    ifft(C, N);

    // 输出

    

    for(int i = 0; i < N; i ++) 
        cout << C[i] << endl;

    

    return 0;
}

参考


  1. 十分简明易懂的FFT(快速傅里叶变换) ↩︎ ↩︎ ↩︎ ↩︎ ↩︎ ↩︎

  2. 快速傅里叶变换(FFT)和逆快速傅里叶变换(IFFT) ↩︎

  3. C++ complex类 ↩︎

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值