题目链接
思路
相信大家都已经学会了
F
F
T
FFT
FFT,若不会,请看这篇博客
一个同学写的,自认为不错
在我们的
F
F
T
FFT
FFT中,我们使用了复数来进行计算
但是我们可以发现复数的乘法时间复杂度是
O
(
4
)
O(4)
O(4)
而
d
o
u
b
l
e
double
double的计算则更加增添了时间复杂度
同时因为浮点数计算
\sqrt{\ \ \ }
的精度会有误差
导致我们最终的所有部分的和反而与完整的
360
360
360°
那么我们不妨试想,若是在系数为整数的情况下
或者需要取模的时候,我们该如何来解决呢?
NTT
于是我们就可以介绍今天的主角了
NTT —— 快速数论变换
是一种建立在数论上的对FFT的优化
(或者可以说是取模运算的FFT)
只不过由于FFT用到是复数
而且double在做了大量的实数运算之后
精度损失较大
而我们的NTT就可以在模意义下
快速做这样的一个多项式乘法
NTT常数小一些
一般这个模数被认为是
x
∗
2
k
+
1
x * 2^k+1
x∗2k+1
原根
接下来我们介绍一个东西——原根
设
m
m
m是正整数,
a
a
a是整数
若
a
a
a模
m
m
m的阶等于
φ
(
m
)
\varphi(m)
φ(m)
则称
a
a
a为模
m
m
m的一个原根。
假设一个数
g
g
g是
P
P
P的原根
那么
g
i
m
o
d
P
g^i\ mod\ P
gi mod P的结果两两不同
且有
1
<
g
<
P
1<g<P
1<g<P,
0
<
i
<
P
0<i<P
0<i<P
归根到底就是
g
P
−
1
≡
1
(
m
o
d
P
)
g^{P-1} \equiv 1 (mod\ P)
gP−1≡1(mod P)
当且仅当指数为
P
−
1
P-1
P−1的时候成立(
P
P
P是素数)。
简单来说,
g
i
m
o
d
p
≠
g
j
m
o
d
p
g^i\ mod\ p \neq g^j mod\ p
gi mod p=gjmod p (
p
p
p为素数)
其中
i
≠
j
i \ne j
i=j 且
i
,
j
i, j
i,j介于
1
1
1至
(
p
−
1
)
(p-1)
(p−1)之间
则
g
g
g为
p
p
p的原根。
提供一种暴力的原根的求法
int calc(int x)// 求原根
{
if (x == 2)
return 1;
for (int i = 2;i ;i++)
{
bool f = 1;
for (int j = 2;j * j < x; j++)
if (qkpow(i, (x - 1) / j, x) == 1)
{
f = 0;
break;
}
if (f)
return i;
}
}
下面是一个常用模数的表
转载自Rayment_cc大佬的博客
NTT
而有了原根以后
我们就可以将复数替换成原根的
m
o
d
−
1
2
∗
i
\frac{mod - 1}{2\ *\ i}
2 ∗ imod−1次方
然后将FFT套上去
就可以实现了
至于为什么,请看这里
于是我们就搞定了原根与单位复根的转换
code
版本1:FFT
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cstdlib>
#include <cctype>
#include <map>
#include <queue>
#include <set>
#include <vector>
#include <iostream>
#include <stack>
#include <complex>
#include <cmath>
#define pk putchar(' ')
#define ph puts("")
#pragma GCC optimize(2)
using namespace std;
typedef long long ll;
typedef complex<double> cpd;
template <class T>
void rd(T &x)
{
x = 0;
int f = 1;
char c = getchar();
while (!isdigit(c)) {if (c == '-') f = -1; c = getchar();}
while (isdigit(c)) x = (x << 3) + (x << 1) + (c ^ 48), c = getchar();
x *= f;
}
template <class T>
void pt(T x)
{
if (x < 0)
putchar('-'), x = (~x) + 1;
if (x > 9)
pt(x / 10);
putchar(x % 10 ^ 48);
}
template <class T>
T Max(T a, T b)
{
return a > b ? a : b;
}
template <class T>
T Min(T a, T b)
{
return a < b ? a : b;
}
template <class T>
T Fabs(T x)
{
return x > 0 ? x : -x;
}
template <class T>
void Swap(T &x, T &y)
{
T t = x;
x = y;
y = t;
}
const int N = (1 << 18) + 5;
const double pi = acos(-1.0);
int n, R[N], l, lim, m;
cpd a[N], b[N];
void fft(cpd *a, int opt)
// opt = 1 => dft
// opt = -1 => idft
{
for (int i = 0; i < lim; i++)
if (i < R[i])
Swap(a[i], a[R[i]]);
for (int i = 1;i < lim; i <<= 1)
{
cpd x{cos(pi / i), opt * sin(pi / i)}, y{1, 0};
for (int j = 0;j < lim; j += (i << 1), y = {1, 0})
for (int k = 0;k < i; k++, y *= x)
{
cpd p = a[j + k], q = y * a[j + k + i];
a[j + k] = p + q;
a[j + k + i] = p - q;
}
}
}
// void ptcp(cpd x)
// {
// printf ("%f+%fi\n", x.real(), x.imag());
// return;
// }
int main()
{
rd(n), rd(m);
for (int i = 0, x;i <= n; i++)
rd(x), a[i] = x;
for (int i = 0, x;i <= m; i++)
rd(x), b[i] = x;
for (lim = 1;lim <= n + m; lim <<= 1)
l++;
for (int i = 0;i < lim; i++)
R[i] = (R[i >> 1] >> 1) | ((i & 1) << (l - 1));
fft(a, 1);
fft(b, 1);
for (int i = 0;i <= lim; i++)
a[i] = a[i] * b[i];
fft(a, -1);
for (int i = 0;i <= n + m; i++)
pt(int((a[i].real() + 0.5) / lim)), pk;
return 0;
}
版本2:NTT
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cstdlib>
#include <cctype>
#include <map>
#include <queue>
#include <set>
#include <vector>
#include <iostream>
#include <stack>
#include <complex>
#include <cmath>
#define pk putchar(' ')
#define ph puts("")
#pragma GCC optimize(2)
using namespace std;
typedef long long ll;
template <class T>
void rd(T &x)
{
x = 0;
int f = 1;
char c = getchar();
while (!isdigit(c)) {if (c == '-') f = -1; c = getchar();}
while (isdigit(c)) x = (x << 3) + (x << 1) + (c ^ 48), c = getchar();
x *= f;
}
template <class T>
void pt(T x)
{
if (x < 0)
putchar('-'), x = (~x) + 1;
if (x > 9)
pt(x / 10);
putchar(x % 10 ^ 48);
}
template <class T>
T Max(T a, T b)
{
return a > b ? a : b;
}
template <class T>
T Min(T a, T b)
{
return a < b ? a : b;
}
template <class T>
T Fabs(T x)
{
return x > 0 ? x : -x;
}
template <class T>
void Swap(T &x, T &y)
{
T t = x;
x = y;
y = t;
}
const int N = (1 << 18) + 5, mod = 998244353;
int n, R[N], l, lim, m, a[N], b[N], g;
int qkpow(int x, int y, int mod)
{
int res = 1;
while (y)
{
if (y & 1)
res = 1ll * res * x % mod;
y >>= 1;
x = 1ll * x * x % mod;
}
return res;
}
int calc(int x)// 求原根
{
if (x == 2)
return 1;
for (int i = 2;i ;i++)
{
bool f = 1;
for (int j = 2;j * j < x; j++)
if (qkpow(i, (x - 1) / j, x) == 1)
{
f = 0;
break;
}
if (f)
return i;
}
}
void ntt(int *a, int opt)
// opt = 1 => dft
// opt = -1 => idft
{
for (int i = 0; i < lim; i++)
if (i < R[i])
Swap(a[i], a[R[i]]);
for (int i = 1;i < lim; i <<= 1)
{
int tmp = qkpow(g, (mod - 1) / (i << 1), mod);
if (opt == -1)
tmp = qkpow(tmp, mod - 2, mod);
for (int j = 0, y = 1;j < lim; j += (i << 1), y = 1)
for (int k = 0;k < i; k++, y = 1ll * y * tmp % mod)
{
int p = a[j + k], q = 1ll * y * a[j + k + i] % mod;
a[j + k] = (p + q) % mod;
a[j + k + i] = (p - q + mod) % mod;
}
}
if (opt == -1)
{
int inv = qkpow(lim, mod - 2, mod);
for (int i = 0;i < lim; i++)
a[i] = 1ll * inv * a[i] % mod;
}
}
int main()
{
rd(n), rd(m);
for (int i = 0;i <= n; i++)
rd(a[i]);
for (int i = 0;i <= m; i++)
rd(b[i]);
for (lim = 1;lim <= n + m; lim <<= 1)
l++;
for (int i = 0;i < lim; i++)
R[i] = (R[i >> 1] >> 1) | ((i & 1) << (l - 1));
g = calc(mod);
ntt(a, 1);
ntt(b, 1);
for (int i = 0;i <= lim; i++)
a[i] = 1ll * a[i] * b[i] % mod;
ntt(a, -1);
for (int i = 0;i <= n + m; i++)
pt(a[i]), pk;
return 0;
}
版本3(函数版NTT,带求乘法逆)
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cstdlib>
#include <cctype>
#include <map>
#include <queue>
#include <set>
#include <vector>
#include <iostream>
#include <stack>
#include <complex>
#include <cmath>
#define pk putchar(' ')
#define ph puts("")
#pragma GCC optimize(2)
using namespace std;
typedef long long ll;
template <class T>
void rd(T &x)
{
x = 0;
ll B = 1;
char c = getchar();
while (!isdigit(c)) {if (c == '-') B = -1; c = getchar();}
while (isdigit(c)) x = (x << 3) + (x << 1) + (c ^ 48), c = getchar();
x *= B;
}
template <class T>
void pt(T x)
{
if (x < 0)
putchar('-'), x = (~x) + 1;
if (x > 9)
pt(x / 10);
putchar(x % 10 ^ 48);
}
template <class T>
T Max(T a, T b)
{
return a > b ? a : b;
}
template <class T>
T Min(T a, T b)
{
return a < b ? a : b;
}
template <class T>
T Fabs(T x)
{
return x > 0 ? x : -x;
}
template <class T>
void Swap(T &x, T &y)
{
T t = x;
x = y;
y = t;
}
const int N = (1 << 19) + 5, mod = 998244353, p = 3;
int n, R[N], c[N], a[N], b[N], m;
int qkpow(int x, int y, int mod)
{
int res = 1;
while (y)
{
if (y & 1)
res = 1ll * res * x % mod;
y >>= 1;
x = 1ll * x * x % mod;
}
return res;
}
void ntt(int *a, int opt, int lim, int mod)
// opt = 1 => dft
// opt = -1 => idft
{
int p_inv = opt == -1 ? qkpow(p, mod - 2, mod) : 0;
for (int i = 0; i < lim; i++)
if (i < R[i])
Swap(a[i], a[R[i]]);
for (int i = 1;i < lim; i <<= 1)
{
int tmp = qkpow(opt == 1 ? p : p_inv, (mod - 1) / (i << 1), mod);
for (int j = 0, y = 1;j < lim; j += (i << 1), y = 1)
for (int k = 0;k < i; k++, y = 1ll * y * tmp % mod)
{
int p = a[j + k], q = 1ll * y * a[j + k + i] % mod;
a[j + k] = (p + q) % mod;
a[j + k + i] = (p - q + mod) % mod;
}
}
if (opt == -1)
{
int inv = qkpow(lim, mod - 2, mod);
for (ll i = 0;i < lim; i++)
a[i] = 1ll * inv * a[i] % mod;
}
}
void mul(int *a, int *b, int n, int m)
{
int lim, l;
for (l = 0, lim = 1;lim < n + m; l++, lim <<= 1);
for (int i = 0;i < lim; R[i] = (R[i >> 1] >> 1) | ((i & 1) << (l - 1)), i++);
ntt(a, 1, lim, mod);
ntt(b, 1, lim, mod);
for (int i = 0;i < lim; i++)
a[i] = 1ll * a[i] * b[i] % mod;
ntt(a, -1, lim, mod);
}
void inv(int *a, int *b, int siz)
{
if (siz == 1)
{
b[0] = qkpow(a[0], mod - 2, mod);
return;
}
inv(a, b, (siz + 1) >> 1);
int lim, l;
for (l = 0, lim = 1;lim < siz + siz; l++, lim <<= 1);
for (int i = 0;i < lim; R[i] = (R[i >> 1] >> 1) | ((i & 1) << (l - 1)), i++);
for (int i = 0;i < siz; i++)
c[i] = a[i];
for (int i = siz;i < lim; i++)
b[i] = 0;
ntt(c, 1, lim, mod);
ntt(b, 1, lim, mod);
for (int i = 0;i < lim; i++)
b[i] = (b[i] * 2 % mod - 1ll * b[i] * b[i] % mod * c[i] % mod + mod) % mod;
ntt(b, -1, lim, mod);
for (int i = siz;i < lim; i++)
b[i] = 0;
}
int main()
{
// freopen("testdata.in", "r", stdin);
// freopen("test.out", "w", stdout);
rd(n), rd(m);
n++, m++;
for (int i = 0;i < n; i++)
rd(a[i]);
for (int j = 0;j < m; j++)
rd(b[j]);
mul(a, b, n, m);
for (int i = 0;i < n + m - 1; i++)
pt(a[i]), pk;
return 0;
}
不要吝啬您的小鼠标,右上角的小手手,请点击下