多项式求逆
若 $f(x) g(x) \equiv 1 \pmod{x^n}$,则称 $f(x)$ 与 $g(x)$ 互为模 $x^n$ 意义下的逆元。
下面我们讨论给定 $f(x)$,求其逆 $f^{-1}(x)$。
倍增求解
假设我们已经求得 $f(x)$ 模 $x^{\lceil \frac{n}{2} \rceil}$ 下的逆元 $f_0^{-1}(x)$,要求 $f^{-1}(x)$,即模 $x^n$ 下的逆元,则
$$f(x) f_0^{-1}(x) \equiv 1 \pmod{x^{\lceil \frac{n}{2} \rceil}}$$
显然 $f(x) f^{-1}(x) \equiv 1 \pmod{x^{\lceil \frac{n}{2} \rceil}}$ 也是成立的。
将上面两式相减,得到 $f(x) \left( f^{-1}(x) - f_0^{-1}(x) \right) \equiv 0 \pmod{x^{\lceil \frac{n}{2} \rceil}}$,再对两边同时乘以 $f_0^{-1}(x)$ 并整理有
$$f^{-1}(x) - f_0^{-1}(x) \equiv 0 \pmod{x^{\lceil \frac{n}{2} \rceil}}$$
对两边同时平方(由于 $2 \lceil \frac{n}{2} \rceil \ge n$,平方后模数可以取为 $x^n$)得到
$$f^{-2}(x) - 2 f^{-1}(x) f_0^{-1}(x) + f_0^{-2}(x) \equiv 0 \pmod{x^n}$$
我们再对两边乘上一个 $f(x)$,则有
$$f^{-1}(x) - 2 f_0^{-1}(x) + f(x) f_0^{-2}(x) \equiv 0 \pmod{x^n}$$
再对其进行移项可得
$$f^{-1}(x) \equiv f_0^{-1}(x) \left( 2 - f(x) f_0^{-1}(x) \right) \pmod{x^n}$$
由此我们递归求解即可。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

// Capacity (in ints) of every global work array below.
const int N = 5000010;

// Prime modulus for all arithmetic; 998244353 = 119 * 2^23 + 1, so
// power-of-two-sized number-theoretic transforms exist modulo it.
const int mod = 998244353;

// a: coefficients of the input polynomial, filled by main.
int a[N];
// b: coefficients of the inverse, produced by polyinv and printed by main.
int b[N];
// c: scratch copy of the input that polyinv transforms in place.
int c[N];
// r: bit-reversal permutation table written by get_r and read by NTT.
int r[N];
// Modular exponentiation by repeated squaring: returns a^n modulo `mod`.
//
// a: base.
// n: exponent; the loop consumes it one bit at a time from the low end.
// Every product is widened to 64 bits before the reduction so it cannot
// overflow a 32-bit int.
int quick_pow(int a, int n) {
    int result = 1;
    for (; n; n >>= 1) {
        // Fold the current power of the base in when this exponent bit is set.
        if (n & 1) {
            result = 1ll * result * a % mod;
        }
        // Square the base for the next (higher) bit.
        a = 1ll * a * a % mod;
    }
    return result;
}
// Fills the global table r[0 .. lim) used by NTT to permute its input.
//
// Each entry is derived from the entry for i / 2: that value is shifted right
// by one, and lim / 2 is added when i is odd. For a power-of-two lim this
// yields the bit-reversed index of i.
void get_r(int lim) {
    const int top_bit = lim >> 1;
    for (int i = 0; i < lim; ++i) {
        const int shifted = r[i >> 1] >> 1;
        r[i] = (i & 1) ? top_bit + shifted : shifted;
    }
}
// In-place number-theoretic transform of f[0 .. lim) modulo `mod`.
//
// f:   coefficient (or point-value) array, overwritten with the result.
// lim: transform length; the permutation table r must already have been
//      prepared for this length (see get_r).
// rev: pass -1 for the inverse transform; any other value performs the
//      forward transform.
void NTT(int *f, int lim, int rev) {
    // Reorder the input according to r; the i < r[i] test swaps each pair once.
    for (int i = 0; i < lim; ++i) {
        const int j = r[i];
        if (i < j) {
            std::swap(f[i], f[j]);
        }
    }

    // Butterfly passes over blocks of doubling width.
    for (int half = 1; half < lim; half <<= 1) {
        const int width = half << 1;
        // Root of unity of order `width`, built from the primitive root 3.
        const int root = quick_pow(3, (mod - 1) / width);
        for (int base = 0; base < lim; base += width) {
            int twiddle = 1;
            for (int k = 0; k < half; ++k) {
                int &lo = f[base + k];
                int &hi = f[base + half + k];
                const int x = lo;
                const int y = 1ll * twiddle * hi % mod;
                lo = (x + y) % mod;
                hi = (x - y + mod) % mod;
                twiddle = 1ll * twiddle * root % mod;
            }
        }
    }

    if (rev != -1) {
        return;
    }

    // Inverse transform: reverse f[1 .. lim) and scale every entry by the
    // modular inverse of lim.
    const int inv = quick_pow(lim, mod - 2);
    std::reverse(f + 1, f + lim);
    for (int *p = f; p != f + lim; ++p) {
        *p = 1ll * *p * inv % mod;
    }
}
// Computes the inverse of the polynomial a modulo x^n (coefficients modulo
// `mod`) and stores its n coefficients in b[0 .. n).
//
// a: input coefficients a[0 .. n). a[0] must be non-zero modulo `mod`,
//    otherwise no inverse exists and the output is meaningless.
// b: output buffer. It must have room for `lim` ints, where lim is the
//    smallest power of two that is >= 2 * n, because the transforms run in
//    place on it. On return b[n .. lim) is zero. The buffer does not need to
//    be cleared by the caller.
// n: number of coefficients wanted; n <= 0 is a no-op.
//
// Side effects: overwrites the global scratch array c[0 .. lim) and the global
// permutation table r (through get_r).
//
// Method: recursively obtain the inverse b0 modulo x^ceil(n/2), then lift it
// with the Newton step b = b0 * (2 - a * b0) (mod x^n), evaluated pointwise
// in the transform domain.
void polyinv(int *a, int *b, int n) {
    // Without this guard n <= 0 would never reach the n == 1 base case and
    // the recursion would not terminate.
    if (n <= 0) {
        return;
    }
    if (n == 1) {
        // Inverse of the constant term via Fermat's little theorem.
        b[0] = quick_pow(a[0], mod - 2);
        return;
    }

    const int half = (n + 1) >> 1;
    polyinv(a, b, half);

    // The product a * b0 * b0 has fewer than 2 * n coefficients, so a
    // transform of length >= 2 * n holds it without wrap-around.
    int lim = 1;
    while (lim < 2 * n) {
        lim <<= 1;
    }
    get_r(lim);

    // Only b[0 .. half) is meaningful after the recursive call. Clear the rest
    // of the transform window so stale data (e.g. from an earlier call on the
    // same buffer) cannot leak into the result.
    for (int i = half; i < lim; i++) {
        b[i] = 0;
    }

    // c = a mod x^n, zero-padded to the transform length.
    for (int i = 0; i < n; i++) {
        c[i] = a[i];
    }
    for (int i = n; i < lim; i++) {
        c[i] = 0;
    }

    NTT(b, lim, 1);
    NTT(c, lim, 1);
    for (int i = 0; i < lim; i++) {
        // Pointwise Newton step: b <- b * (2 - c * b).
        const int cur = (2 - 1ll * c[i] * b[i] % mod + mod) % mod;
        b[i] = 1ll * b[i] * cur % mod;
    }
    NTT(b, lim, -1);

    // Truncate to x^n: coefficients of degree >= n are not part of the answer.
    for (int i = n; i < lim; i++) {
        b[i] = 0;
    }
}
int main() {
// freopen("in.txt", "r", stdin);
// freopen("out.txt", "w", stdout);
// ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
int n;
scanf("%d", &n);
for (int i = 0; i < n; i++) {
scanf("%d", &a[i]);
}
polyinv(a, b, n);
for (int i = 0; i < n; i++) {
printf("%d%c", b[i], i + 1 == n ? '\n' : ' ');
}
return 0;
}