【模板】多项式乘法(FFT)
题目背景
这是一道多项式乘法模板题。
注意:本题并不属于中国计算机学会划定的提高组知识点考察范围。
题目描述
给定一个 n n n 次多项式 F ( x ) F(x) F(x),和一个 m m m 次多项式 G ( x ) G(x) G(x)。
请求出 F ( x ) F(x) F(x) 和 G ( x ) G(x) G(x) 的卷积。
输入格式
第一行两个整数 n , m n,m n,m。
接下来一行 n + 1 n+1 n+1 个数字,从低到高表示 F ( x ) F(x) F(x) 的系数。
接下来一行 m + 1 m+1 m+1 个数字,从低到高表示 G ( x ) G(x) G(x) 的系数。
输出格式
一行 n + m + 1 n+m+1 n+m+1 个数字,从低到高表示 F ( x ) ⋅ G ( x ) F(x) \cdot G(x) F(x)⋅G(x) 的系数。
样例 #1
样例输入 #1
1 2
1 2
1 2 1
样例输出 #1
1 4 5 2
提示
保证输入中的系数大于等于 0 0 0 且小于等于 9 9 9。
对于 100 % 100\% 100% 的数据: 1 ≤ n , m ≤ 10 6 1 \le n, m \leq {10}^6 1≤n,m≤106。
原题
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
// NTT的特殊要求:模数P需满足P=k*2^m+1,例如常用的模数998244353=7*17*2^23+1
const int Mod = 998244353;
const int G = 3; // 998244353的原根为3
const int maxn = 1e6 + 6;
int QuickPow(int a, int k) // 快速幂
{
int ret = 1;
while (k)
{
if (k & 1)
ret = 1LL * ret * a % Mod;
a = 1LL * a * a % Mod;
k >>= 1;
}
return ret;
}
int rev[3 * maxn];
inline void GetReverse(int len, int lg) // 获得二进制位的反转
{
for (int i = 0; i < len; i++)
rev[i] = 0;
for (int i = 0; i < len; i++)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lg - 1));
}
void NTT(vector<int> &a, int dir)
{
int len = a.size();
for (int i = 0; i < len; i++)
if (i < rev[i])
swap(a[i], a[rev[i]]);
int g = (dir == 1 ? G : QuickPow(G, Mod - 2));
for (int stp = 1; stp < len; stp <<= 1)
{
int wn = QuickPow(g, (Mod - 1) / (stp << 1));
int w = 1;
for (int k = 0; k < stp; k++)
{
for (int even = k; even < len; even += stp << 1)
{
int odd = even + stp;
int tmp = 1LL * w * a[odd] % Mod;
a[odd] = (a[even] - tmp + Mod) % Mod;
a[even] = (ll)(a[even] + tmp) % Mod;
}
w = 1LL * w * wn % Mod;
}
}
if (dir == -1)
{
int inv = QuickPow(len, Mod - 2); // 乘法逆元
for (int i = 0; i < len; i++)
a[i] = 1LL * a[i] * inv % Mod;
}
}
int main()
{
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
// 以下代码实现了多项式f(x)和多项式g(x)相乘
int n, m;
cin >> n >> m;
int lg = 0;
int bound = 1; // 将多项式的系数扩充到2的整数次幂
while (bound <= n + m)
{
bound <<= 1;
lg++;
}
GetReverse(bound, lg);
vector<int> f(bound, 0), g(bound, 0); // 数组大小开到bound(即结果多项式扩充后的大小)
for (int i = 0; i <= n; i++)
{
cin >> f[i];
}
for (int i = 0; i <= m; i++)
{
cin >> g[i];
}
NTT(f, 1);
NTT(g, 1);
for (int i = 0; i < bound; i++)
{
f[i] = (1LL * f[i] * g[i]) % Mod; // 将相乘结果存储在f[]数组中
}
NTT(f, -1);
for (int i = 0; i <= n + m; i++)
{
cout << f[i] << ' ';
}
return 0;
}