题目大多来自于2022杭电多校与牛客多校,且绝大多数为NTT,学了一个暑假的,从初学者的角度理解题目,理论学习方面是从oiwiki上学来的,板子用的是jls板子。
FFT与NTT的应用
FFT是使用复数来进行卷积的,对于两个多项式
A=a0x^0+a1x^1+a2x^2 +.......+anx^n
B=b0x^0+b1x^1+b2x^3 +.......+bmx^n
我们将A和B的系数分别表示成 A数组和B数组,
A={a0,a1,a2......an}
B={b0,b1,b2,....bm
若我们要计算A*B,一个朴素的做法是进行一个二重循环,计算每一项的值后生成一个新的数组C
具体代码为
for(int i=0;i<=n;i++)
{
for(int j=0;j<=m;j++)
{
C[i+j]+=A[i]*B[j];
}
}
然而大多数情况下我们遇到的题目是不会让你O(n^2)解决这个问题的,这个时候,我们就要利用到FFT和NTT了,它们能在O(nlogn)内计算出C数组。
具体的原理oiwiki上讲的很好了,你以为我不打算讲原理了吗,这么有魅力的东西,我还是想稍微讲讲的,这里只讲解FFT,因为NTT是带模意义下的FFT,原理本质上是一样的
一.原理篇
学习路线1:
1.复数篇
复数加减乘除高中我们已经学过了,我们引入复数的目的是为了单位根
单位根:x^n=1在复数意义下是n次复数根,即x有n个解,即为n次单位根,
我们定义w(n,i)=exp(2Π*i/n)(0 <= i < n) 这样我们就有了n个点
注:关于exp :设复数z=x+iy ,函数f(z)=e^x(cosy+isiny),定义exp(z) =f(z)
对于exp有几个重要性质
1.exp(z1+z2)=exp(z1)*exp(z2) (后面将会用到)
2.exp z是以2Πi为周期的周期函数,即可以在复数单位圆上表示出每一个点
我们可以发现这n个解对应就是将单元圆均等分成n个的n个点
如图,当n等于8时对应的是图上的八个点
这样我们能够知道单位根有三个性质
(1)w(n,n)=1; 回到原点
(2)w(n,k)=w(2n,2k) 等倍增长
(3)w(2n,k+n)=-w(2n,k) 对角线
注释:后面有用处
学习路线2:
快速傅里叶变换 - OI Wiki (oi-wiki.org)
快速傅里叶变换(FFT)
(1)具体实现步骤
1.将A,B变成点值表示法
2.得到C的点值表示法
3.得到C的系数表示法
1.离散傅里叶变换(DFT):将A,B变成点值表示法
离散傅里叶变换是将多项式从系数表示法变成点值表示法的算法。
对于一个n阶多项式A,当我们知道其中的n个点时,我们就可以列出来n个等式,
a0x1^0+a1x1^1+a2x1^2 +......anx1^n =A(x1)
a0x2^0+a1x2^1+a2x2^2 +......anx2^n =A(x2)
.......
a0xn^0+a1xn^1+a2xn^2 +......anxn^n =A(xn)
而由线代知识这个方程是有解的,所以我们可以由n个点来得到多项式的n个系数
那么多项式就可以在系数表示法和点值表示法之间相互转换
而DFT就是实现多项式从系数表示法到点值表示法的算法,时间复杂度为O(nlogn)
FFT是基于分治的算法
f(x)=a0+a1x^1+a2x^2+.....a3x^3 ...... anx^m
=(a0+a2x^2+a4x^4+...) +x(a1+a3x^2+a5x^4.....)
那么我们将他的奇数项和偶数项分开得到
设G(x)=a0+a2x+a4x^2+.... 偶数保留不变
H(x)=a1+a3x+a5x^2 +..... 奇数项除以x
则f(x)=G(x^2) +x*H(x^2)
我们前面又提到过
1.偶数单位根的性质3 w(n,i)=-w(n,i+n/2)
2.设复数z=x+iy ,函数f(z)=e^x(cosy+isiny),定义exp(z) =f(z)
exp函数的性质exp(z1*z2)=exp(z1)*exp(z2),
3.w(n,i)=exp(2Π*i/n)
所以w和exp具有相同的性质
并且我们知道G(x^2) 和H(x^2)为偶函数,
f(w(n,k)) =G( ( w(n,k) )^2 ) +w(n,k) * H( ( w(n,k) )^2 )
=G( w(n,2k) ) +w(n,k)*H( w(n,2k) )
=G( w(n/2,k) +w(n,k)*H( w(n/2,k) )
f(w(n,k+n/2)) =G( ( w(n,2k+n) ) +w(n,k+n/2) * H( ( w(n,2k+n) )
=G( w(n,2k) ) - w(n,k)*H( w(n,2k) )
=G( w(n/2,k) - w(n,k)*H( w(n/2,k) )
这样的话我们就可以通过 G和H一次性得到两个点
并且我们我们可以分治去递归每一层的G和H来得到新的点,这样复杂度就是nlogn的了
注:这里要注意通过分治必须得保证长度位2的次方,所以我们要将数组补齐到最近的2的次方
但通过递归实现会多了常数,于是有了位逆序变换来优化掉这一层常数
位逆序变换也称蝴蝶变换,我们知道分治递归去处理的话,是将偶数放左边,奇数放右边,而一直递归到相当于对数组进行一个重新排序,那么我们是否可以先得到最终数组的排序,然后一直递推回去得到最终答案呢,答案是可以的
通过研究得到这个位逆序变换得到的最终数组是有规律的,它的下标即为它原本下标的二进制数翻转过来的数,比如3为 011,翻转后为110 =6,那么x3就放在第6个位置,具体实现这个排序的算法在下面放的代码中
到此我们就把dft的最终版本实现了
放上jls的dft代码
constexpr double PI = std::atan2(0, -1);
std::vector<int> rev;
std::vector<std::complex<double>> roots {0, 1};
void dft(std::vector<std::complex<double>> &a) {
int n = a.size();
//位逆序变换
if (int(rev.size()) != n) {
int k = __builtin_ctz(n) - 1;
rev.resize(n);
for (int i = 0; i < n; ++i)
rev[i] = rev[i >> 1] >> 1 | (i & 1) << k;
}
for (int i = 0; i < n; ++i)
if (rev[i] < i)
swap(a[i], a[rev[i]]);
//递推实现dft
if (int(roots.size()) < n) {
int k = __builtin_ctz(roots.size());
roots.resize(n);
while ((1 << k) < n) {
std::complex<double> e = {cos(PI / (1 << k)), sin(PI / (1 << k))};
for (int i = 1 << (k - 1); i < (1 << k); ++i) {
roots[2 * i] = roots[i];
roots[2 * i + 1] = roots[i] * e;
}
++k;
}
}
for (int k = 1; k < n; k *= 2) {
for (int i = 0; i < n; i += 2 * k) {
for (int j = 0; j < k; ++j) {
auto u = a[i + j], v = a[i + j + k] * roots[k + j];
a[i + j] = u + v;
a[i + j + k] = u - v;
}
}
}
}
2.前面我们已经知道DFT可以得到AB的点值表示法,接下来我们是要得到C的点值表示法
假设A*B=C
假设我们知道A上的n个点
(x1,y1) (x2,y2) (x3,y3) .... (xn,yn)
B上的n个点
(x1,y1') (x2,y2') (x3.y3') .... (xn,yn')
注意这n个点的横坐标相同
那么我们C=A*B,
即C上的n个点为 (x1,y1*y1') (x2,y2*y2') (x3,y3*y3') .....(xn,yn*yn')
这样我们是可以通过on的复杂度将A,B上的n个点转化到C上的n个点的
3.得到C的点值表示法后,我们要得到C的系数表示法
这里用的是离散傅里叶反变换(idft)
idft的作用是将多项式的点值表达式变成系数表达式,本身是一个求逆的过程
这个逆矩阵是范德蒙德矩阵,逆矩阵非常特殊,这也是fft巧妙地原因,我们可以继续利用dft的模板,只需要将每一项取倒数再除以变换长度n后,做一遍dft,就实现我们的目的了,到此我们就实现了fft
void idft(std::vector<std::complex<double>> &a) {
int n = a.size();
reverse(a.begin() + 1, a.end());
dft(a);
for (int i = 0; i < n; ++i)
a[i] /= n;
}
jls FFT整个板子
constexpr double PI = std::atan2(0, -1);
std::vector<int> rev;
std::vector<std::complex<double>> roots {0, 1};
void dft(std::vector<std::complex<double>> &a) {
int n = a.size();
if (int(rev.size()) != n) {
int k = __builtin_ctz(n) - 1;
rev.resize(n);
for (int i = 0; i < n; ++i)
rev[i] = rev[i >> 1] >> 1 | (i & 1) << k;
}
for (int i = 0; i < n; ++i)
if (rev[i] < i)
swap(a[i], a[rev[i]]);
if (int(roots.size()) < n) {
int k = __builtin_ctz(roots.size());
roots.resize(n);
while ((1 << k) < n) {
std::complex<double> e = {cos(PI / (1 << k)), sin(PI / (1 << k))};
for (int i = 1 << (k - 1); i < (1 << k); ++i) {
roots[2 * i] = roots[i];
roots[2 * i + 1] = roots[i] * e;
}
++k;
}
}
for (int k = 1; k < n; k *= 2) {
for (int i = 0; i < n; i += 2 * k) {
for (int j = 0; j < k; ++j) {
auto u = a[i + j], v = a[i + j + k] * roots[k + j];
a[i + j] = u + v;
a[i + j + k] = u - v;
}
}
}
}
void idft(std::vector<std::complex<double>> &a) {
int n = a.size();
reverse(a.begin() + 1, a.end());
dft(a);
for (int i = 0; i < n; ++i)
a[i] /= n;
}
std::vector<int64_t> operator*(std::vector<int64_t> a, std::vector<int64_t> b) {
int sz = 1, tot = a.size() + b.size() - 1;
while (sz < tot)
sz *= 2;
std::vector<std::complex<double>> ca(sz), cb(sz);
copy(a.begin(), a.end(), ca.begin());
copy(b.begin(), b.end(), cb.begin());
dft(ca);
dft(cb);
for (int i = 0; i < sz; ++i)
ca[i] *= cb[i];
idft(ca);
a.resize(tot);
for (int i = 0; i < tot; ++i)
a[i] = std::floor(ca[i].real() + 0.5);
return a;
}
ntt是取模意义下的fft,需要有原根
jls ntt板子
using namespace std;
using pii = pair<int, int>;
using i64 = long long;
constexpr int P = 998244353, G = 3;
template <class T>
T power(T a, int b)
{
T res = 1;
for (; b; b >>= 1, a *= a)
if (b & 1)
res *= a;
return res;
}
int norm(int x)
{
if (x < 0) x += P;
if (x >= P) x -= P;
return x;
}
struct Z
{
int x;
Z(int x = 0) : x(norm(x)) {}
int val() const
{
return x;
}
Z operator-() const
{
return Z(norm(P - x));
}
Z inv() const
{
assert(x != 0);
return power(*this, P - 2);
}
Z &operator*=(const Z &rhs)
{
x = i64(x) * rhs.x % P;
return *this;
}
Z &operator+=(const Z &rhs)
{
x = norm(x + rhs.x);
return *this;
}
Z &operator-=(const Z &rhs)
{
x = norm(x - rhs.x);
return *this;
}
Z &operator/=(const Z &rhs)
{
return *this *= rhs.inv();
}
friend Z operator*(const Z &lhs, const Z &rhs)
{
Z res = lhs;
res *= rhs;
return res;
}
friend Z operator+(const Z &lhs, const Z &rhs)
{
Z res = lhs;
res += rhs;
return res;
}
friend Z operator-(const Z &lhs, const Z &rhs)
{
Z res = lhs;
res -= rhs;
return res;
}
friend Z operator/(const Z &lhs, const Z &rhs)
{
Z res = lhs;
res /= rhs;
return res;
}
};
vector<int> rev;
vector<Z> roots{0, 1};
void dft(vector<Z> &a)
{
int n = a.size();
if (int(rev.size()) != n)
{
int k = __builtin_ctz(n) - 1;
rev.resize(n);
for (int i = 0; i < n; i++)
rev[i] = rev[i >> 1] >> 1 | (i & 1) << k;
}
for (int i = 0; i < n; i++)
if (rev[i] < i)
swap(a[i], a[rev[i]]);
if (int(roots.size()) < n)
{
int k = __builtin_ctz(roots.size());
roots.resize(n);
while ((1 << k) < n)
{
Z e = power(Z(G), (P - 1) >> (k + 1));
for (int i = 1 << (k - 1); i < (1 << k); i++)
{
roots[2 * i] = roots[i];
roots[2 * i + 1] = roots[i] * e;
}
k++;
}
}
for (int k = 1; k < n; k *= 2)
for (int i = 0; i < n; i += 2 * k)
for (int j = 0; j < k; j++)
{
Z u = a[i + j];
Z v = a[i + j + k] * roots[k + j];
a[i + j] = u + v;
a[i + j + k] = u - v;
}
}
void idft(vector<Z> &a)
{
int n = a.size();
reverse(a.begin() + 1, a.end());
dft(a);
Z inv = (1 - P) / n;
for (int i = 0; i < n; i++)
a[i] *= inv;
}
vector<Z> operator*(const vector<Z> &a, const vector<Z> &b) //NTT
{
int sz = 1, tot = a.size() + b.size() - 1;
while (sz < tot) sz *= 2;
vector<Z> ca(sz), cb(sz);
copy(a.begin(), a.end(), ca.begin());
copy(b.begin(), b.end(), cb.begin());
dft(ca); dft(cb);
for (int i = 0; i < sz; ++i) ca[i] = ca[i] * cb[i];
idft(ca);
ca.resize(tot);
return ca;
}