(原理篇)FFT与NTT (快速傅里叶变换)摘自oiwi,仅方便理解

题目大多来自于2022杭电多校与牛客多校,且绝大多数为NTT,学了一个暑假的,从初学者的角度理解题目,理论学习方面是从oiwiki上学来的,板子用的是jls板子。

FFT与NTT的应用

FFT是使用复数来进行卷积的,对于两个多项式

A=a0x^0+a1x^1+a2x^2 +.......+anx^n

B=b0x^0+b1x^1+b2x^3 +.......+bmx^n

我们将A和B的系数分别表示成 A数组和B数组,

A={a0,a1,a2......an}

B={b0,b1,b2,....bm

若我们要计算A*B,一个朴素的做法是进行一个二重循环,计算每一项的值后生成一个新的数组C

具体代码为

for(int i=0;i<=n;i++)
{
	for(int j=0;j<=m;j++)
	{
		C[i+j]+=A[i]*B[j];
	}
}

然而大多数情况下我们遇到的题目是不会让你O(n^2)解决这个问题的,这个时候,我们就要利用到FFT和NTT了,它们能在O(nlogn)内计算出C数组。

具体的原理oiwiki上讲的很好了,你以为我不打算讲原理了吗,这么有魅力的东西,我还是想稍微讲讲的,这里只讲解FFT,因为NTT是带模意义下的FFT,原理本质上是一样的

一.原理篇

学习路线1:

复数 - OI Wiki (oi-wiki.org)

1.复数篇

复数加减乘除高中我们已经学过了,我们引入复数的目的是为了单位根

单位根:x^n=1在复数意义下是n次复数根,即x有n个解,即为n次单位根,

我们定义w(n,i)=exp(2Π*i/n)(0 <= i < n) 这样我们就有了n个点 

注:关于exp :设复数z=x+iy ,函数f(z)=e^x(cosy+isiny),定义exp(z) =f(z)

对于exp有几个重要性质

1.exp(z1+z2)=exp(z1)*exp(z2)  (后面将会用到)

2.exp z是以2Πi为周期的周期函数,即可以在复数单位圆上表示出每一个点

我们可以发现这n个解对应就是将单元圆均等分成n个的n个点

如图,当n等于8时对应的是图上的八个点 

这样我们能够知道单位根有三个性质

(1)w(n,n)=1;  回到原点

(2)w(n,k)=w(2n,2k)  等倍增长   

(3)w(2n,k+n)=-w(2n,k)  对角线

注释:后面有用处

学习路线2:

快速傅里叶变换 - OI Wiki (oi-wiki.org)

快速傅里叶变换(FFT)

(1)具体实现步骤

1.将A,B变成点值表示法

2.得到C的点值表示法

3.得到C的系数表示法

1.离散傅里叶变换(DFT)将A,B变成点值表示法

离散傅里叶变换是将多项式从系数表示法变成点值表示法的算法。

对于一个n阶多项式A,当我们知道其中的n个点时,我们就可以列出来n个等式,

a0x1^0+a1x1^1+a2x1^2 +......anx1^n =A(x1)

a0x2^0+a1x2^1+a2x2^2 +......anx2^n =A(x2)

.......

a0xn^0+a1xn^1+a2xn^2 +......anxn^n =A(xn)

而由线代知识这个方程是有解的,所以我们可以由n个点来得到多项式的n个系数

那么多项式就可以在系数表示法和点值表示法之间相互转换

而DFT就是实现多项式从系数表示法到点值表示法的算法,时间复杂度为O(nlogn)

FFT是基于分治的算法

f(x)=a0+a1x^1+a2x^2+.....a3x^3 ...... anx^m

     =(a0+a2x^2+a4x^4+...) +x(a1+a3x^2+a5x^4.....)

那么我们将他的奇数项和偶数项分开得到

设G(x)=a0+a2x+a4x^2+.... 偶数保留不变

   H(x)=a1+a3x+a5x^2 +..... 奇数项除以x

则f(x)=G(x^2) +x*H(x^2)

我们前面又提到过

1.偶数单位根的性质3 w(n,i)=-w(n,i+n/2)

2.设复数z=x+iy ,函数f(z)=e^x(cosy+isiny),定义exp(z) =f(z)

  exp函数的性质exp(z1*z2)=exp(z1)*exp(z2),

3.w(n,i)=exp(2Π*i/n)

所以w和exp具有相同的性质

并且我们知道G(x^2) 和H(x^2)为偶函数,

              f(w(n,k)) =G( ( w(n,k) )^2  ) +w(n,k) * H( ( w(n,k) )^2 )

                           =G( w(n,2k) ) +w(n,k)*H( w(n,2k) )

                           =G( w(n/2,k)  +w(n,k)*H( w(n/2,k) )

      f(w(n,k+n/2)) =G( ( w(n,2k+n)  ) +w(n,k+n/2) * H( ( w(n,2k+n)  )

                           =G( w(n,2k) ) - w(n,k)*H( w(n,2k) )

                           =G( w(n/2,k)  - w(n,k)*H( w(n/2,k) )

       这样的话我们就可以通过 G和H一次性得到两个点

并且我们我们可以分治去递归每一层的G和H来得到新的点,这样复杂度就是nlogn的了

注:这里要注意通过分治必须得保证长度位2的次方,所以我们要将数组补齐到最近的2的次方

但通过递归实现会多了常数,于是有了位逆序变换来优化掉这一层常数

 

位逆序变换也称蝴蝶变换,我们知道分治递归去处理的话,是将偶数放左边,奇数放右边,而一直递归到相当于对数组进行一个重新排序,那么我们是否可以先得到最终数组的排序,然后一直递推回去得到最终答案呢,答案是可以的

通过研究得到这个位逆序变换得到的最终数组是有规律的,它的下标即为它原本下标的二进制数翻转过来的数,比如3为 011,翻转后为110 =6,那么x3就放在第6个位置,具体实现这个排序的算法在下面放的代码中

 到此我们就把dft的最终版本实现了

放上jls的dft代码

constexpr double PI = std::atan2(0, -1);
std::vector<int> rev;
std::vector<std::complex<double>> roots {0, 1};
void dft(std::vector<std::complex<double>> &a) {
    int n = a.size();
    //位逆序变换
    if (int(rev.size()) != n) {
        int k = __builtin_ctz(n) - 1;
        rev.resize(n);
        for (int i = 0; i < n; ++i)
            rev[i] = rev[i >> 1] >> 1 | (i & 1) << k;
    }
    for (int i = 0; i < n; ++i)
        if (rev[i] < i)
            swap(a[i], a[rev[i]]);
    //递推实现dft
    if (int(roots.size()) < n) {
        int k = __builtin_ctz(roots.size());
        roots.resize(n);
        while ((1 << k) < n) {
            std::complex<double> e = {cos(PI / (1 << k)), sin(PI / (1 << k))};
            for (int i = 1 << (k - 1); i < (1 << k); ++i) {
                roots[2 * i] = roots[i];
                roots[2 * i + 1] = roots[i] * e;
            }
            ++k;
        }
    }
    for (int k = 1; k < n; k *= 2) {
        for (int i = 0; i < n; i += 2 * k) {
            for (int j = 0; j < k; ++j) {
                auto u = a[i + j], v = a[i + j + k] * roots[k + j];
                a[i + j] = u + v;
                a[i + j + k] = u - v;
            }
        }
    }
}

2.前面我们已经知道DFT可以得到AB的点值表示法,接下来我们是要得到C的点值表示法

假设A*B=C

假设我们知道A上的n个点

(x1,y1) (x2,y2) (x3,y3) .... (xn,yn)

B上的n个点

(x1,y1') (x2,y2') (x3.y3') .... (xn,yn') 

注意这n个点的横坐标相同

那么我们C=A*B,

即C上的n个点为 (x1,y1*y1') (x2,y2*y2') (x3,y3*y3') .....(xn,yn*yn')

这样我们是可以通过on的复杂度将A,B上的n个点转化到C上的n个点的

3.得到C的点值表示法后,我们要得到C的系数表示法

这里用的是离散傅里叶反变换(idft)

idft的作用是将多项式的点值表达式变成系数表达式,本身是一个求逆的过程

 这个逆矩阵是范德蒙德矩阵,逆矩阵非常特殊,这也是fft巧妙地原因,我们可以继续利用dft的模板,只需要将每一项取倒数再除以变换长度n后,做一遍dft,就实现我们的目的了,到此我们就实现了fft

void idft(std::vector<std::complex<double>> &a) {
    int n = a.size();
    reverse(a.begin() + 1, a.end());
    dft(a);
    for (int i = 0; i < n; ++i)
        a[i] /= n;
}

jls FFT整个板子

constexpr double PI = std::atan2(0, -1);
std::vector<int> rev;
std::vector<std::complex<double>> roots {0, 1};
void dft(std::vector<std::complex<double>> &a) {
    int n = a.size();
    if (int(rev.size()) != n) {
        int k = __builtin_ctz(n) - 1;
        rev.resize(n);
        for (int i = 0; i < n; ++i)
            rev[i] = rev[i >> 1] >> 1 | (i & 1) << k;
    }
    for (int i = 0; i < n; ++i)
        if (rev[i] < i)
            swap(a[i], a[rev[i]]);
    if (int(roots.size()) < n) {
        int k = __builtin_ctz(roots.size());
        roots.resize(n);
        while ((1 << k) < n) {
            std::complex<double> e = {cos(PI / (1 << k)), sin(PI / (1 << k))};
            for (int i = 1 << (k - 1); i < (1 << k); ++i) {
                roots[2 * i] = roots[i];
                roots[2 * i + 1] = roots[i] * e;
            }
            ++k;
        }
    }
    for (int k = 1; k < n; k *= 2) {
        for (int i = 0; i < n; i += 2 * k) {
            for (int j = 0; j < k; ++j) {
                auto u = a[i + j], v = a[i + j + k] * roots[k + j];
                a[i + j] = u + v;
                a[i + j + k] = u - v;
            }
        }
    }
}
void idft(std::vector<std::complex<double>> &a) {
    int n = a.size();
    reverse(a.begin() + 1, a.end());
    dft(a);
    for (int i = 0; i < n; ++i)
        a[i] /= n;
}
std::vector<int64_t> operator*(std::vector<int64_t> a, std::vector<int64_t> b) {
    int sz = 1, tot = a.size() + b.size() - 1;
    while (sz < tot)
        sz *= 2;
    std::vector<std::complex<double>> ca(sz), cb(sz);
    copy(a.begin(), a.end(), ca.begin());
    copy(b.begin(), b.end(), cb.begin());
    dft(ca);
    dft(cb);
    for (int i = 0; i < sz; ++i)
        ca[i] *= cb[i];
    idft(ca);
    a.resize(tot);
    for (int i = 0; i < tot; ++i)
        a[i] = std::floor(ca[i].real() + 0.5);
    return a;
}

ntt是取模意义下的fft,需要有原根

jls ntt板子

using namespace std;
using pii = pair<int, int>;
using i64 = long long;
 
constexpr int P = 998244353, G = 3;
 
template <class T>
T power(T a, int b)
{
    T res = 1;
    for (; b; b >>= 1, a *= a)
        if (b & 1)
            res *= a;
    return res;
}
 
int norm(int x)
{
    if (x < 0) x += P;
    if (x >= P) x -= P;
    return x;
}
struct Z
{
    int x;
    Z(int x = 0) : x(norm(x)) {}
    int val() const
    {
        return x;
    }
    Z operator-() const
    {
        return Z(norm(P - x));
    }
    Z inv() const
    {
        assert(x != 0);
        return power(*this, P - 2);
    }
    Z &operator*=(const Z &rhs)
    {
        x = i64(x) * rhs.x % P;
        return *this;
    }
    Z &operator+=(const Z &rhs)
    {
        x = norm(x + rhs.x);
        return *this;
    }
    Z &operator-=(const Z &rhs)
    {
        x = norm(x - rhs.x);
        return *this;
    }
    Z &operator/=(const Z &rhs)
    {
        return *this *= rhs.inv();
    }
    friend Z operator*(const Z &lhs, const Z &rhs)
    {
        Z res = lhs;
        res *= rhs;
        return res;
    }
    friend Z operator+(const Z &lhs, const Z &rhs)
    {
        Z res = lhs;
        res += rhs;
        return res;
    }
    friend Z operator-(const Z &lhs, const Z &rhs)
    {
        Z res = lhs;
        res -= rhs;
        return res;
    }
    friend Z operator/(const Z &lhs, const Z &rhs)
    {
        Z res = lhs;
        res /= rhs;
        return res;
    }
};
 
vector<int> rev;
vector<Z> roots{0, 1};
void dft(vector<Z> &a)
{
    int n = a.size();
 
    if (int(rev.size()) != n)
    {
        int k = __builtin_ctz(n) - 1;
        rev.resize(n);
        for (int i = 0; i < n; i++)
            rev[i] = rev[i >> 1] >> 1 | (i & 1) << k;
    }
 
    for (int i = 0; i < n; i++)
        if (rev[i] < i)
            swap(a[i], a[rev[i]]);
    if (int(roots.size()) < n)
    {
        int k = __builtin_ctz(roots.size());
        roots.resize(n);
        while ((1 << k) < n)
        {
            Z e = power(Z(G), (P - 1) >> (k + 1));
            for (int i = 1 << (k - 1); i < (1 << k); i++)
            {
                roots[2 * i] = roots[i];
                roots[2 * i + 1] = roots[i] * e;
            }
            k++;
        }
    }
    for (int k = 1; k < n; k *= 2)
        for (int i = 0; i < n; i += 2 * k)
            for (int j = 0; j < k; j++)
            {
                Z u = a[i + j];
                Z v = a[i + j + k] * roots[k + j];
                a[i + j] = u + v;
                a[i + j + k] = u - v;
            }
}
void idft(vector<Z> &a)
{
    int n = a.size();
    reverse(a.begin() + 1, a.end());
    dft(a);
    Z inv = (1 - P) / n;
    for (int i = 0; i < n; i++)
        a[i] *= inv;
}
vector<Z> operator*(const vector<Z> &a, const vector<Z> &b) //NTT
{
    int sz = 1, tot = a.size() + b.size() - 1;
    while (sz < tot) sz *= 2;
    vector<Z> ca(sz), cb(sz);
    copy(a.begin(), a.end(), ca.begin());
    copy(b.begin(), b.end(), cb.begin());
    dft(ca); dft(cb);
    for (int i = 0; i < sz; ++i) ca[i] = ca[i] * cb[i];
    idft(ca);
    ca.resize(tot);
    return ca;
}

 

  • 2
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值