前言
FFT是很优美的一个方法,可以解决多项式中的许多问题,NTT则是在模p意义下的多项式乘积,
两者十分相似,所不同的是两者所在数域的单位根不同,
一个是
ω
=
e
2
π
n
i
\omega = e^{\frac{2\pi }{n}i}
ω=en2πi,另一个是
g
,
(
p
∤
g
i
,
∀
i
)
g,(p\nmid g^i,\forall i)
g,(p∤gi,∀i)
下面给出非递归的代码模板,用到了二进制位逆序变换、蝴蝶操作
值得注意的是,遇见1004535809和998244353要敏感些,极大可能就是NTT的题
一、例题
题目链接 洛谷-P3803
二、思路及代码
1.思路
很单纯的一道FFT/NTT,直接套模板
2.代码
FFT版本:
#include <cmath>
#include <iostream>
#define int long long
#define pi acos(-1.0)
using namespace std;
const int maxn = 1e7 + 7;
int n, m;
int N, len;
int rev[maxn];
struct complex {
double r, i;
complex() {}
complex(double a, double b) : r(a), i(b) {}
} a[maxn], b[maxn], ans[maxn];
complex operator+(const complex a, const complex b) {
return complex(a.r + b.r, a.i + b.i);
}
complex operator-(const complex a, const complex b) {
return complex(a.r - b.r, a.i - b.i);
}
complex operator*(const complex a, const complex b) {
return complex(a.r * b.r - a.i * b.i, a.i * b.r + a.r * b.i);
}
void FFT(complex a[], int inv) {
for (int i = 0; i < N; i++) // 0 1 2 3 4 5 6 7 -> 0 4 2 6 1 5 3 7
if (i < rev[i]) swap(a[i], a[rev[i]]);
for (int k = 1; k < N; k <<= 1) {
complex wn(cos(pi / k), inv * sin(pi / k));
for (int i = 0; i < N; i += 2 * k) {
complex w(1, 0), x, y; // butterfly
for (int j = 0; j < k; j++) {
x = a[i + j], y = w * a[i + j + k];
a[i + j] = (x + y), a[i + j + k] = (x - y);
w = w * wn;
}
}
}
if (inv == -1)
for (int i = 0; i < N; i++) a[i].r = a[i].r / N;
}
signed main() {
// freopen("in.txt", "r", stdin);
// freopen("out.txt", "w", stdout);
scanf("%d%d", &n, &m); // 系数从低位到高位
for (int i = 0; i <= n; i++) scanf("%lf", &a[i].r);
for (int i = 0; i <= m; i++) scanf("%lf", &b[i].r);
N = 1, len = 0;
while (N <= n + m) N <<= 1, len++; // N = 2 ^ (log m + n)
for (int i = 0; i < N; i++) // 二进制位逆序置换 1100 -> 0011
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (len - 1));
FFT(a, 1);
FFT(b, 1); // 不同函数值存储于系数中
for (int i = 0; i < N; i++) ans[i] = a[i] * b[i];
FFT(ans, -1); //逆变换得到系数
for (int i = 0; i <= n + m; i++) printf("%d ", (int)(ans[i].r + 0.5)); // eps
return 0;
}
NTT版本:
#include <iostream>
#define int long long
using namespace std;
const int maxn = 1e7 + 7;
const int mod = 998244353;
const int g = 3; // 原根g
int n, m;
int N, len;
int rev[maxn];
int a[maxn], b[maxn], ans[maxn];
int quickpow(int a, int n) {
int ans = 1;
while (n) {
if (n & 1) ans = ans * a % mod;
n >>= 1;
a = a * a % mod;
}
return ans;
}
int getinv(int a) { return quickpow(a, mod - 2); }
void NTT(int a[], int N, int inv) {
for (int i = 0; i < N; i++) // 0 1 2 3 4 5 6 7 -> 0 4 2 6 1 5 3 7
if (i < rev[i]) swap(a[i], a[rev[i]]);
for (int k = 1; k < N; k <<= 1) {
int wn = quickpow(g, (mod - 1) / (2 * k));
if (inv == -1) wn = getinv(wn);
for (int i = 0; i < N; i += 2 * k) {
int w = 1, x, y; // butterfly
for (int j = 0; j < k; j++) {
x = a[i + j], y = w * a[i + j + k] % mod;
a[i + j] = (x + y + mod) % mod, a[i + j + k] = (x - y + mod) % mod;
w = w * wn % mod;
}
}
}
if (inv == -1) {
int val = getinv(N);
for (int i = 0; i < N; i++) a[i] = a[i] * val % mod;
}
}
signed main() {
// freopen("in.txt", "r", stdin);
// freopen("out.txt", "w", stdout);
scanf("%d %d", &n, &m); // 系数从低位到高位
for (int i = 0; i <= n; i++) scanf("%d", &a[i]);
for (int i = 0; i <= m; i++) scanf("%d", &b[i]);
N = 1, len = 0;
while (N <= n + m) N <<= 1, len++; // N = 2 ^ (log m + n)
for (int i = 0; i < N; i++) // 二进制位逆序置换 1100 -> 0011
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (len - 1));
NTT(a, N, 1);
NTT(b, N, 1); // 不同函数值存储于系数中
for (int i = 0; i < N; i++) ans[i] = a[i] * b[i];
NTT(ans, N, -1); //逆变换得到系数
for (int i = 0; i <= n + m; i++) printf("%lld ", (int)ans[i]); // eps
return 0;
}