https://www.luogu.org/problem/P4721
这道题不算是一道裸的多项式求逆模板。
首先这道题题目说的是分治FFT,我反正弄了几天,开始学FFT就看到这道题了,没有弄出来,到处瞎想,太弱了。
随着知识的深入后来发现这是一道模板题。
其实这道题和生成函数应该不是很大,即使你不知道,我想可能乱搞也会出来。。。。。
我们来开始推导一下。
首先有这个式子:
然后我们知道一个序列可以表示成多项式。我们设生成函数F(x)和G(x):
我们对于没有的地方全部规定为零。(和信号与系统的离散信号那套差不多)
这里g[0]没有说明,所以为0,其他没有说明的也为零。
我们开始做卷积:
我们在变一下:
我们发现后面的一项很熟悉,因为g[0]等于0,当k=0时后面的一部分等于0,当k>0时,就等于f[k].
卷积就直接等于:
和F(x)就差了一项f(0)x^0=f(0),又因为f(0)等于1所以在变一下:
就是多项式求逆了。
下面稍微联想一下信号与系统上关于离散卷积的运算:
根据离散卷积的公式:
吧g和f当成由于这都是正时间轴上的序列,其他时域不存在信号因此其他算进去了也无意义,因此带入上面的公式新的序列
因此这样就轻松得到了上面的卷积答案,并且还知道了,新序列y的开始地方就是g和h开始的坐标相加。
其实我就是这样推的。。。。。根本没有用生成函数(电子专业)。
剩下的就是多项式求逆,就是一个模板,也帖一个模板方便自己。
#include "bits/stdc++.h"
using namespace std;
const double eps = 1e-6;
#define reg register
#define lowbit(x) x&-x
#define pll pair<ll,ll>
#define pii pair<int,int>
#define fi first
#define se second
#define makp make_pair
#define cp complex<double>
int dcmp(double x) {
if (fabs(x) < eps) return 0;
return (x > 0) ? 1 : -1;
}
typedef long long ll;
typedef unsigned long long ull;
const ull hash1 = 201326611;
const ull hash2 = 50331653;
const ll N = 280000 + 10;
const int M = 1000000;
const int inf = 0x3f3f3f3f;
const ll mod = 998244353;
const double PI = acos(-1.0);
ll Mod(ll x) {
if (x >= mod) x -= mod;
return x;
}
ll quick(ll a, ll n) {
ll ans = 1;
while (n) {
if (n & 1) ans = ans * a % mod;
a = a * a % mod;
n >>= 1;
}
return ans;
}
ll r[N], g, tot, lim;
void init(int len) {
tot = 1, lim = 0;
while (tot < 2 * len) tot <<= 1, lim++;
for (int i = 0; i < tot; i++) {
r[i] = (r[i >> 1] >> 1) | ((i & 1) << (lim - 1));
}
}
void ntt(ll *a, int tot, int inv) {
for (int i = 0; i < tot; i++) {
if (i < r[i]) swap(a[i], a[r[i]]);
}
for (int l = 2; l <= tot; l <<= 1) {
ll tmp = quick(g, (mod - 1) / l);
if (inv) tmp = quick(tmp, mod - 2);
int m = l / 2;
for (int j = 0; j < tot; j += l) {
ll w = 1;
for (int i = 0; i < m; i++) {
ll t = 1LL * a[j + i + m] * w % mod;
a[j + i + m] = Mod(a[j + i] - t + mod);
a[j + i] = Mod(a[j + i] + t);
w = 1LL * w * tmp % mod;
}
}
}
if (inv) {
ll t = quick(tot, mod - 2);
for (int i = 0; i < tot; i++) {
a[i] = 1LL * a[i] * t % mod;
}
}
}
int n;
ll a[N], b[N], c[N];
void solve(int len, ll *a, ll *b) {
if (len == 1) {
b[0] = quick(a[0], mod - 2);
return;
}
solve((len + 1) >> 1, a, b);
init(len);
for (int i = 0; i < len; i++) c[i] = a[i];
for (int i = len; i < tot; i++) c[i] = 0;
ntt(c, tot, 0);
ntt(b, tot, 0);
for (int i = 0; i < tot; i++) {
b[i] = 1LL * (2 - 1LL * c[i] * b[i] % mod + mod) % mod * b[i] % mod;
}
ntt(b, tot, 1);
for (int i = len; i < tot; i++) b[i] = 0;
}
int main() {
scanf("%d", &n);
for (int i = 1; i < n; i++) {
scanf("%lld", &a[i]);
a[i] = -a[i];
}
a[0] = 1;
g = 3;
solve(n, a, b);
for (int i = 0; i < n; i++)
printf("%lld ", b[i]);
return 0;
}