https://www.cnblogs.com/zzqsblog/p/5665654.html
下面其实是写给自己查阅的,真正要学可以看上面的。
OI里的fft并没有那么神奇,可以简单理解为加速卷积(多项式)的工具。
设n次多项式A,B,本来需要n^2的时间求其乘积(卷积),使用fft可以加速成n log n.
#n次单位复数根
满足w^n=1的复数。
由复数乘法性质,幅角相加,长度相乘可知,w其实就是将单位圆均分为n份的那n个复数。
记为w1,…wn.
显然
w
j
=
c
o
s
(
j
∗
2
p
i
/
n
)
+
s
i
n
(
j
∗
2
p
i
/
n
)
∗
i
wj=cos(j*2pi/n) + sin(j*2pi/n)*i
wj=cos(j∗2pi/n)+sin(j∗2pi/n)∗i
#DFT
离散傅里叶变换(求点值表达,下文点值表达特指x取遍n次单位复数根的点值表达)
下文称a次多项式的次数界为a+1。
为了方便,次数界应补足为最近的2的幂。 (高位系数设0)
求一个次数界为n的多项式,当x取n单位复数根时(w0…wn-1)的n个值。
主要思想是分治,拆分为奇数幂次与偶数幂次,由于单位复数根的特殊性质(平方后减半)
最终式子:
A
[
i
]
=
A
0
[
i
]
+
A
1
[
i
]
∗
W
i
,
i
<
n
/
2
A[i]~~~~~~~~~~~=A0[i] + A1[i] * W^i,i<n/2
A[i] =A0[i]+A1[i]∗Wi,i<n/2
A
[
i
+
n
/
2
]
=
A
0
[
i
]
−
A
1
[
i
]
∗
W
i
,
i
<
n
/
2
A[i+n/2]=A0[i] -A1[i] * W^i,i<n/2
A[i+n/2]=A0[i]−A1[i]∗Wi,i<n/2
#逆DFT
已知次数界为n的多项式在n次单位复数根下的点值表达,求系数表达。
一发推导后可以发现,只要把n次主根取
w
−
1
w^{-1}
w−1,按照原先做dft,再将最后结果/次数界即可。
最终结果可能有误差,需要加上一个eps.
#蝶形变换
小常数实现fft的方法。
先按二进制翻转(上限取次数界-1),然后从左到右做。
#NTT (数论变换)
模意义下的dft/idft
在模一个费马模数的前提下(
P
=
k
2
a
+
1
P=k2^a+1
P=k2a+1,比如998244353,g=3),我们可以将n次单位主根单位根替换为
g
P
−
1
n
g^{\frac {P-1} {n}}
gnP−1,其中g是P的原根。
且满足n<=2^a
假设一个数g是P的原根,那么g^i mod P的结果两两不同,且有 1 < g < P , 0 < i < P 1<g<P,0<i<P 1<g<P,0<i<P,
逆dft中,负指数需要用逆元。注意long long问题
#调试方法
无???
如何验证对错:
对拍
观察fft后虚部是否为0。
对一个数列dft,idft看是否前后一致
#DFT
#include <cstdio>
#include <iostream>
#include <cmath>
#include <complex>
typedef double db;
typedef long long ll;
#define com complex<db>
using namespace std;
const int N=1e5+10;
const db pi=acos(-1);
int n,h[N*6],M;
com q[N*6],r[N*6];
com a[N*6];
void dft(com *src,int sig) {
for (int i=0; i<M; i++) a[h[i]]=src[i]; //蝶形变换的准备
for (int m=2; m<=M; m<<=1) { //正在求的组大小
int half=m>>1;
for (int i=0; i<half; i++) { //求A[i]与A[i+k],先枚举这个方便处理主根
com w=com(cos(i*2*pi/m) , sig * sin(i*2*pi/m));
//必须一步一求,不然精度会出锅。
//由于转移到m,因此按照式子是m次根。
for (int j=i; j<M; j+=m) { //第几组
int k=j+half;
com u=a[j], v=a[k]*w;
a[j]=u+v,a[k]=u-v;
}
}
}
for (int i=0; i<M; i++) src[i]=a[i];
}
int main() {
//freopen("3617.in","r",stdin);
cin>>n;
for (int i=0; i<n; i++) scanf("%lf",&q[i].real());
for (M=1; M<3*n; M<<=1);
for (int i=0; i<2*n-1; i++) {
if (i==n-1) r[i]=0; else
r[i]=pow(i-(n-1),-2) * (i<n-1?-1:1);
}
for (int i=0; i<M; i++) h[i]=(h[i>>1]>>1) + ((i&1) * (M>>1));
//以次数界-1为长度,翻转二进制
dft(q,1); dft(r,1);
for (int i=0; i<M; i++) q[i]=q[i]*r[i];
dft(q,-1);
for (int i=n-1; i<n+n-1; i++) printf("%lf\n",q[i].real() / M);
//不要忘记除次数界!!!
}
#NTT
#include <cstdio>
#include <iostream>
using namespace std;
const int N=4e5+10,mo=998244353,g=3;
typedef long long ll;
int n,m,M;
int A[N],B[N],h[N];
ll w[N], iw[N];
ll ksm(ll x,ll y) {
if (y==0) return 1;
if (y==1) return x;
ll t=ksm(x,y>>1);
return t*t%mo*ksm(x,y&1)%mo;
}
void ntt(int *a,int sz,int sig) {
for (int i = 1; i < sz; i++)
h[i] = (h[i>>1]>>1) + (i & 1) * (sz >> 1);
for (int i = 0; i <sz; i++)
if (h[i]<i) swap(a[i],a[h[i]]);
for (int m = 1; m < sz; m<<=1) {
int td = M / (m<<1);
for (int i = 0; i < sz; i += (m<<1)) {
for (int j = 0; j < m; j++) {
ll T = a[i+j+m] * (sig == 1 ? w[td * j] : iw[td * j]) % mo;
a[i+j+m] = (a[i+j] - T) % mo;
a[i+j] = (a[i+j] + T) % mo;
}
}
}
}
int main() {
freopen("test.in","r",stdin);
cin>>n>>m;;
for (int i=0; i<=n; i++) scanf("%d",&A[i]);
for (int i=0; i<=m; i++) scanf("%d",&B[i]);
for (M=1; M<=n+m; M<<=1);
for (int i=1; i<M; i++)
h[i]=(h[i>>1]>>1) + (i&1) * (M>>1);
ll ww = ksm(3, (mo - 1) / M);
iw[0] = w[0] = 1;
for (int i = 1; i < M; i++) w[i] = w[i-1] * ww % mo;
ww = ksm(ww, mo - 2);
for (int i = 1; i < M; i++) iw[i] = iw[i-1] * ww % mo;
ntt(A,M,1);
ntt(B,M,1);
for (int i=0; i<M; i++) A[i]=(ll)A[i]*B[i]%mo;
ntt(A,M,-1);
ll cs=ksm(M,mo-2);
for (int i=0; i<=n+m; i++) printf("%lld ",(A[i]*cs%mo+mo)%mo);
}