FFT && NTT
fft 这么重要的东西还是记录一下吧(雾)
基本概念
多项式的点值表示法和系数表示法:
- fft 实际上是使用 O ( n l o g n ) O(nlogn) O(nlogn) 的时间将一个用系数表示的多项式转换成它的点值表示的算法
- 多项式的点值表示法可以相互转化- d f t 和 i d f t dft和idft dft和idft
点值表示法
点值表示法就是用一个 n + 1 n+1 n+1维度的向量来表示一个 n n n度的多项式
比方 A ( x ) = a 0 + a 1 x + a 2 x 2 . . . + a n x n A(x) =a_0+a_1x+a_2x^2...+a_nx^n A(x)=a0+a1x+a2x2...+anxn
可以表示成
( ( x 0 , A ( x 0 ) ) , ( x 1 , A ( x 1 ) , . . . , ( x n , A ( x n ) ) ) ) ((x_0,A(x_0)), (x_1,A(x_1),...,(x_n,A(x_n)))) ((x0,A(x0)),(x1,A(x1),...,(xn,A(xn))))
如何实现dft?
首先我们要选择有特征的 x x x来时这一过程简化(大雾,感觉自己根本不适合写引导类的东西,直接上结论吧)
- 使用 x n = 1 x^n=1 xn=1的一组复数根来作为 x x x
ω n 0 , ω n 1 , ω n 2 , . . ω n n − 1 \omega_n^0,\omega_n^1,\omega_n^2,..\omega_n^{n-1} ωn0,ωn1,ωn2,..ωnn−1
- 通过分治来使单个的计算简化
关于单位复根的一些性质
ω n k = ω 2 n 2 k ω n n = 1 ω 2 n k + n = − ω 2 n k \omega_n^k=\omega_{2n}^{2k}\\ \omega_n^n=1\\\omega_{2n}^{k+n}=-\omega_{2n}^k ωnk=ω2n2kωnn=1ω2nk+n=−ω2nk
想一下复平面内的旋转或者欧拉公式就可以了
另外 ω n i = c o s ( 2 π i n ) + s i n ( 2 π i n ) i = e 2 π i n i \omega_n^i=cos(\frac{2\pi i}{n})+sin(\frac{2\pi i}{n})i=e^{\frac{2\pi i}{n}i} ωni=cos(n2πi)+sin(n2πi)i=en2πii
分治过程的推导
对于度不是二的幂次的多项式我们可以把它改成同阶的多项式
举个例子吧 f ( x ) = a 0 + a 1 x + a 2 x 2 + a 3 x 3 + . . + a 7 x 7 f(x)=a_0+a_1x+a_2x^2+a_3x^3+..+a_7x^7 f(x)=a0+a1x+a2x2+a3x3+..+a7x7
改写
f ( x ) = ( a 0 + a 2 x 2 + a 4 x 4 + a 6 x 6 ) + x ( a 1 + a 3 x 2 + a 5 x 4 + a 7 x 6 ) = g ( x 2 ) + x × h ( x 2 ) f(x)=(a_0+a_2x^2+a_4x^4+a_6x^6)+x(a_1+a_3x^2+a_5x^4+a_7x^6)\\=g(x^2)+x\times h(x^2) f(x)=(a0+a2x2+a4x4+a6x6)+x(a1+a3x2+a5x4+a7x6)=g(x2)+x×h(x2)
这里 h ( x ) 和 g ( x ) h(x)和g(x) h(x)和g(x)都是度为4的多项式
代入 ω n k \omega_n^k ωnk
f ( ω n k ) = h ( ω n 2 k ) + ω n 2 k × g ( ω n 2 k ) = h ( ω n / 2 k ) + ω n k × g ( ω n / 2 k ) f(\omega_n^k)=h(\omega_n^{2k})+\omega_n^{2k}\times g(\omega_n^{2k})\\=h(\omega_{n/2}^k)+\omega_n^k\times g(\omega_{n/2}^{k}) f(ωnk)=h(ωn2k)+ωn2k×g(ωn2k)=h(ωn/2k)+ωnk×g(ωn/2k)
对于 f ( ω n n / 2 + k ) = h ( ω n / 2 k ) − ω n k × g ( ω n / 2 k ) f(\omega_n^{n/2+k})=h(\omega_{n/2}^{k})-\omega_n^k\times g(\omega_{n/2}^{k}) f(ωnn/2+k)=h(ωn/2k)−ωnk×g(ωn/2k)
过程。。利用上面的性质搞一下就出来了
注意到后面两个的形式可以递归去算
递归形式的板子
// mashiroyuki
// dft 递归形式的板子
#include <complex>
#include <cmath>
#define pi M_PI
typedef std::complex<double> cp;
const int maxn=1e5+10;
cp b[maxn]; //参数M_PI在内置库里面有
void dft(cp*a, int n, int inv) {
//inv表示是否求共轭复数inv=1表示dft,inv=-1表示idft
if(n==1) return;
int mid=n>>1;
rep(i,0,n-1) b[i]=a[i];
rep(i,0,mid-1) a[i]=b[i*2], a[i+mid]=b[2*i+1];
dft(a,mid,inv), dft(a+mid,mid,inv); //分治
rep(i,0,mid-1) {
cp x(cos(2*i*pi/n),inv*sin(2*i*pi)/n);
b[i]=a[i]+x*a[i+mid];
b[i+mid]=a[i]-x*a[i+mid];
}
rep(i,0,n-1) a[i]=b[i];
}
非递归形式dft
非递归形式加入了蝴蝶变换。
计算方式:
只给出
O
(
n
)
O(n)
O(n)(其实不会另一种,逃)
//每个i考虑高n-1位然后自然可以用已经处理出来的rev[i>>1]去计算
rev[i]=rev[i>>1]>>1|(i&1)<<bit-1
mashiroyuki
// dft 非递归形式的板子
#include <complex>
#include <cmath>
#define pi M_PI
typedef std::complex<double> cp;
const int maxn=1e5+10;
cp b[maxn]; //参数M_PI
int rev[maxn];
void dft(cp* f, int n, int inv) {
int bit=1;
while((1<<bit)<n) bit++;
rep(i,0,n-1) {
rev[i]=rev[i>>1]>>1|(i&1)<<(bit-1);
if(i<rev[i]) swap(a[i],a[rev[i]]); //这一步要注意处理
}
//根据蝴蝶变换后的序列向上合并,枚举合并序列长度for(int mid=1;mid<n;mid*=2) {
//枚举每一段的起点
cp t(cos(pi/mid), inv*sin(pi/mid));
for(int i=0;i<n;i+=mid*2) {
cp w(1,0);
for(int j=0;j<mid;j++,w*=t) {
cp x=f[i+j], y=f[i+j+mid]*w;
f[i+j]=x+y, f[i+j+mid]=x-y; //步骤基本还是一样
}
}
}
}
除了蝴蝶序列优化没有什么区别,就是模拟了一下递归过程
如何求出fft
- 把两个多项式扩充到最进的二的幂次的长度,同时保证是度相等
- 分别做一遍dft然后直接相乘
- 做一遍idft
代码
cp a[MAXN],b[MAXN];
int c[MAXN];
fft(a,n,1),fft(b,n,1);//1系数转点值
rep(i,0,n-1)a[i]*=b[i];
fft(a,n,-1);//-1点值转系数
rep(i,0,n-1)c[i]=(int)(a[i].real()/n+0.5);//注意精度
// fft
typedef complex<double> cp;
const int maxn=4e6+10;
#define pi M_PI
cp a[maxn], b[maxn];
int rev[maxn], c[maxn];
void dft(cp* f, int n, int inv) {
int bit=0;
while((1<<bit)<n) bit++;
rep(i,0,n-1) {
rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
if(i<rev[i]) swap(f[i],f[rev[i]]);
}
//根据蝴蝶变换后的序列向上合并,枚举合并序列长度
for(int mid=1;mid<n;mid*=2) {
//枚举每一段的起点
cp t(cos(pi/mid), inv*sin(pi/mid));
for(int i=0;i<n;i+=mid*2) {
cp w(1,0);
for(int j=0;j<mid;j++,w*=t) {
cp x=f[i+j], y=f[i+j+mid]*w;
f[i+j]=x+y, f[i+j+mid]=x-y; //步骤基本还是一样
}
}
}
}
void fft(int n, int m) {
n=m+n+1;
int bit=1;
while((1<<bit)<n) bit++;
n=1<<bit;
dft(a,n,1); dft(b,n,1);
rep(i,0,n-1) a[i]*=b[i];
dft(a,n,-1);
rep(i,0,n-1) c[i]=(int)(a[i].real()/n+0.5);
}
// fft
修改后的模板的几点说明:
- idft 后得到的不是直接的系数,要除以度
- 长度要是最接近m+n 的二的幂
fft 结语
蒟蒻终于也入门fft了,www
$$$$
部分内容借鉴于OI wiki