前置知识
DFT离散傅里叶变换
原根
数学上已经证明,在复数域内,具有循环卷积特性的唯一变换是DFT,因此提出了以数论为基础的具有循环卷积性质的快速数论变换。
下面开始介绍NTT,快速数论变换
我们知道在复数域中有单位根:
ω
n
=
1
\omega^n=1
ωn=1,如果我们选取一个素数p,显然有
ω
n
≡
1
(
m
o
d
p
)
\omega^n\equiv1(mod\ p)
ωn≡1(mod p)设g为p的一个原根(素数一定有原根,即
g
p
−
1
≡
1
(
m
o
d
p
)
g^{p-1}\equiv 1(mod\ p)
gp−1≡1(mod p))显然
ω
n
=
g
p
−
1
n
\omega_n=g^{\frac{p-1}{n}}
ωn=gnp−1,设
g
n
=
g
p
−
1
n
g_n=g^{\frac{p-1}{n}}
gn=gnp−1。
所以对于DFT变换公式:
y
k
=
∑
j
=
0
n
−
1
a
j
w
n
k
j
y_k=\sum_{j=0}^{n-1}a_jw_n^{kj}
yk=j=0∑n−1ajwnkj
代入有NTT变换公式:
y
k
=
∑
j
=
0
n
−
1
a
j
g
n
k
j
y_k=\sum_{j=0}^{n-1}a_jg^{kj}_n
yk=j=0∑n−1ajgnkj
对逆DFT公式同理有逆NTT变换公式:
a
j
=
1
n
∑
j
=
0
n
−
1
y
k
g
n
−
k
j
a_j=\frac{1}{n}\sum_{j=0}^{n-1}y_kg^{-kj}_n
aj=n1j=0∑n−1ykgn−kj
不难发现,这个牵扯到一个问题,
g
p
−
1
n
k
j
g^{\frac{p-1}{n}kj}
gnp−1kj不一定是整数,NTT主要优点就是没有FFT的浮点精度问题,且可以取模,如果
g
p
−
1
n
k
j
g^{\frac{p-1}{n}kj}
gnp−1kj不是整数,NTT就丧失了它的意义。所以我们要找到一个合适的素数p,令
g
p
−
1
n
k
j
g^{\frac{p-1}{n}kj}
gnp−1kj是整数,即保证n | (p-1),因为我们在编程的时候为方便起见,n的取值是2的整数幂,在此基础上,我们的必须保证素数p可以写成这种形式:
p
=
c
∗
2
k
+
1
。
p=c*2^k+1。
p=c∗2k+1。其中
2
k
2^k
2k必须大于n,
这里提供一个素数:
p
=
479
∗
2
21
+
1
=
1004535809
p=479*2^{21}+1=1004535809
p=479∗221+1=1004535809,依照上述所讲,对于
n
<
2
21
n<2^{21}
n<221的多项式,都可用这个模数正确求解。
下面是模板代码:
P3803 【模板】多项式乘法(FFT)
const int N = 1e6+5;
const int G = 3;
const ll mod = (479<<21) + 1;
ll f[N], g[N];
int rev[N];
int len, lim = 1;
void init(int n)
{
while(lim <= n) lim <<= 1, len++;
for (int i = 0; i < lim; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (len-1));
}
ll _pow(ll a, ll b) {}
void NTT(ll *a, int op)//op为1是FFT运算,-1是FFT逆运算
{
for (int i = 0; i < lim; i++)
if (i < rev[i]) swap(a[i], a[rev[i]]);
for (int dep = 1; dep <= log2(lim); dep++)
{
int m = 1 << dep;
ll wn = _pow(G, (mod-1) / m);
for (int k = 0; k < lim; k += m)
{
ll w = 1;
for (int j = 0; j < m / 2; j++)
{
ll t = w * (a[k + j + m / 2] % mod) % mod;
ll u = a[k + j] % mod;
a[k + j] =(u + t) % mod;
a[k + j + m / 2] =((u - t) % mod + mod ) % mod;
w = w * wn % mod;
}
}
}
if (op == -1)
{
for (int i = 1; i < lim / 2; i++)
swap(a[i], a[lim-i]);
ll inv = _pow(lim, mod-2);
for (int i = 0; i < lim; i++)
a[i] = a[i] * inv % mod;
}
}
int main()
{
int n, m;
cin >> n >> m;
init(n + m);
for (int i = 0; i <= n; i++) scanf("%lld", &f[i]);
for (int i = 0; i <= m; i++) scanf("%lld", &g[i]);
NTT(g, 1); NTT(f, 1);
for (int i = 0; i <= lim; i++) f[i] = f[i] * g[i] % mod;
NTT(f, -1);
for (int i = 0; i <= n + m; i++)
printf("%lld ", f[i]);
return 0;
}