前置知识:
阶:对于互质整数
a
,
n
a,n
a,n ,满足
a
r
≡
1
m
o
d
n
a^r \equiv 1\mod n
ar≡1modn 的最小的
r
r
r 就是
a
a
a 模
n
n
n 的阶。
原根:对于正整数
n
n
n ,若整数
a
a
a 模
n
n
n 的阶等于
φ
(
n
)
\varphi(n)
φ(n) ,则称
a
a
a 为模
n
n
n 的一个原根。
快速数论变换
N
T
T
NTT
NTT,实际上就是利用一些特殊的质数(例如
998244353
=
119
×
2
23
+
1
998244353=119\times2^{23}+1
998244353=119×223+1,这些质数的特征都是可以可以表示
q
×
2
k
+
1
q \times 2 ^ k + 1
q×2k+1 的形式)进行膜意义下的运算代替浮点复数运算来保证精度,原理是质数原根和复数单位根在DFT运算中具有相同的性质。
下列出我们在
F
F
T
FFT
FFT 中利用的单位复根的性质:
- ω n 0 = ω n n = 1 \omega_n^0=\omega_n^n=1 ωn0=ωnn=1
- 若 i ≠ j m o d n i\not=j\mod n i=jmodn ,有 ω n i ≠ ω n j \omega_n^i\not=\omega_n^j ωni=ωnj
- ω d n d k = ω n k \omega_{dn}^{dk}=\omega_n^k ωdndk=ωnk
- ω n k + n 2 = − ω n k \omega_n^{k+\frac n2}=-\omega_n^k ωnk+2n=−ωnk
- ∑ i = 1 n − 1 ( ω n j − k ) i = [ k = = j ] n \sum_{i=1}^{n-1}(\omega_n^{j-k})^i=[k==j]n ∑i=1n−1(ωnj−k)i=[k==j]n
对于质数
p
p
p ,设
g
g
g 是
p
p
p 的一个原根,则有
g
0
≡
g
φ
(
p
)
≡
1
m
o
d
p
g^0\equiv g^{\varphi(p)}\equiv 1\mod p
g0≡gφ(p)≡1modp ,故对于
i
≢
j
m
o
d
φ
(
p
)
i\not \equiv j\mod \varphi(p)
i≡jmodφ(p) ,有
g
i
≢
g
j
m
o
d
p
g^i\not\equiv g^j\mod p
gi≡gjmodp
对于质数
p
=
k
×
2
N
+
1
p=k\times2^N+1
p=k×2N+1 ,设
g
n
≡
g
p
−
1
n
m
o
d
p
g_n\equiv g^{\frac{p-1}n}\mod p
gn≡gnp−1modp ,这个
g
n
g_n
gn 同样拥有着上述单位复根的
5
5
5 条重要性质。
我们之所以取模数
p
=
k
×
2
N
+
1
p=k\times2^N+1
p=k×2N+1 ,是因为我们要保证
p
−
1
n
\frac{p-1}{n}
np−1 这个指数是一个整数,由于
F
F
T
FFT
FFT 的过程是一个不断把区间大小缩小一半来处理的过程,所以我们可以保证
n
n
n 总是
2
2
2 的幂次,所以模数必须足够多的
2
2
2 ,所以
p
p
p 需要取成
k
×
2
N
+
1
k\times2^N+1
k×2N+1 而且
N
N
N 需要尽可能的大一些
所以在写
N
T
T
NTT
NTT 时,只需要把
F
F
T
FFT
FFT 中的单位复根换成原根即可
具体见代码:
int n,m,len = 2,L,rev[maxn<<1],a[maxn<<1],b[maxn<<1],c[maxn<<1];
const int mod = 998244353;//mod原根为3
const int g = 3;//原根
ll qpow(ll a,ll b)
{
ll res = 1;
while(b)
{
if(b&1) res = res*a%mod;
a = a*a%mod;
b >>= 1;
}
return res;
}
void get_rev()
{
for(int i = 0;i < len;i++)
rev[i] = (rev[i>>1]>>1)|(len>>1)*(i&1);
// for(int i = 0;i < len;i++) rev[i] = (rev[i>>1]>>1)|((i&1)<<(L-1));
}
void ntt(int *a,int dft)
{
for(int i = 0;i < len;i++)
if(i < rev[i]) swap(a[i],a[rev[i]]);//不加这条if会交换两次(就是没交换)
for(int mid = 1;mid < len;mid <<= 1)//mid是准备合并序列的长度的二分之一
{
ll W = qpow(g,(mod-1)/(mid<<1));
if(dft == -1) W = qpow(W,mod-2);
for(int i = 0;i < len;i += mid<<1)//mid*2是准备合并序列的长度,i是合并到了哪一位
{
ll w = 1;
for(int j = i;j < mid+i;j++,w = w*W%mod)//只扫左半部分,得到右半部分的答案
{
int x = a[j];
int y = w*a[j+mid]%mod;
a[j] = (x+y)%mod;
a[j+mid] = ((x-y)%mod+mod)%mod;
}
}
}
if(dft == -1)
{
int tmp = qpow(len,mod-2);
for(int i = 0;i < len;i++) a[i] = (ll)a[i]*tmp%mod;
}
}