好题,证明的过程真的很妙。
正文部分:
题意:
有两个多项式 \(A(x)\) 与 \(B(x)\),满足 \(A(x)B(x)\equiv 1 \pmod{x^n}\)。给出 \(A(x)\),求 \(B(x)\)。
设已求出多项式 \(B'(x)\)(这里的 \(B'\) 只是另一个多项式的记号,不是导数),满足 \(A(x)B'(x)\equiv 1 \pmod{x^{\lceil n/2\rceil}}\)。
注意到 \(A(x)B(x)\equiv 1\) 在模 \(x^{\lceil n/2\rceil}\) 下同样成立,两式相减并消去可逆的 \(A(x)\),于是得到这个式子:
\(B(x)-B'(x)\equiv 0 \pmod{x^{\lceil n/2\rceil}}\)
将两边平方(模数随之变为 \(x^{2\lceil n/2\rceil}\),而 \(2\lceil n/2\rceil\ge n\)),得到:
\(B(x)^2-2B(x)B'(x)+B'(x)^2\equiv 0 \pmod{x^n}\)
两边同乘 \(A(x)\),并利用 \(A(x)B(x)\equiv 1 \pmod{x^n}\),则:
\(B(x)-2B'(x)+A(x)B'(x)^2\equiv 0 \pmod{x^n}\)
移项得:
\(B(x)\equiv 2B'(x)-A(x)B'(x)^2 \pmod{x^n}\)
提出公因式 \(B'(x)\),
则:\(B(x)\equiv B'(x)\bigl(2-A(x)B'(x)\bigr) \pmod{x^n}\)
于是我们就找到了递归式,从上往下递归即可;递归边界为 \(n=1\),此时 \(B(x)\) 就是 \(A(x)\) 常数项在模 \(998244353\) 下的逆元。
My Code
#include <bits/stdc++.h>
// Shorthand macros: `il` expands to `inline`, `gc`/`pc` to getchar/putchar.
#define il inline
#define gc getchar
#define pc putchar
// `mi` captures the built-in int type BEFORE the macro below redefines the
// `int` token; it is used as the return type of main().
typedef int mi;
// From this line on, every `int` token in the file means `long long`.
#define int long long
typedef long long LL;
// Number of elements in each of the global work arrays below.
const int MAXN = 1e6 + 10;
// Modulus for all arithmetic. fpow(x, p - 2) is used as the modular inverse
// of x, and main() derives the NTT twiddle factors from powers of 3 mod p.
const LL p = 998244353;
using namespace std;
// n: number of coefficients, read in main().
// m, i, j, k are not referenced by the code visible here (every loop declares
// its own index) -- NOTE(review): confirm they are unused before removing.
int n,m,i,j,k;
// a: input coefficients; b: coefficients of the computed inverse;
// wn[i]: 3^((p-1)/2^i) mod p, filled for i < 21 in main() and read by ntt().
// R is not referenced by the code visible here.
int R[MAXN],a[MAXN],b[MAXN],wn[MAXN];
namespace IO {
/// Reads the next optionally signed decimal integer from stdin.
///
/// Skips every character up to the first digit; if any skipped character is
/// '-', the result is negated. Consumes the digits plus the one character
/// that follows them.
///
/// The return type is spelled `long long` explicitly; this is the same type
/// the file-wide `#define int long long` gives to `int`.
///
/// @return the parsed value, or 0 if end of input is reached before a digit.
inline long long read() {
    // Keep getchar()'s full return value so EOF stays distinguishable from
    // real characters; a plain `char` would make the skip loop spin forever
    // once input is exhausted.
    int ch = getchar();
    bool negative = false;
    while (ch != EOF && (ch < '0' || ch > '9')) {
        negative |= (ch == '-');
        ch = getchar();
    }
    long long value = 0;
    while (ch >= '0' && ch <= '9') {
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return negative ? -value : value;
}
}
using IO::read;
il int fpow(int a,int b,int mod = p) {
int res = 1;
while(b) {
if(b & 1) res = (res * a) % mod;
b >>= 1;a = (a * a) % mod;
}
return res;
}
// In-place iterative number-theoretic transform of a[0..len) modulo p.
//   a   : coefficient buffer, overwritten with the transform
//   len : transform length; the loops below assume a power of two
//   on  : -1 selects the inverse transform (reorder + scale by len^-1);
//         any other value performs the forward transform only
// Twiddle factors are read from the global table wn[], indexed by butterfly
// level (1 for span 2, 2 for span 4, ...).
il void ntt(LL a[],int len,int on) {
    // Bit-reversal permutation; `rev` is maintained as a bit-reversed counter.
    for (int pos = 0, rev = 0; pos < len; ++pos) {
        if (pos > rev) swap(a[pos], a[rev]);
        int bit = len >> 1;
        while ((rev ^= bit) < bit) bit >>= 1;
    }

    // Butterfly passes, doubling the half-span each level.
    int level = 0;
    for (int half = 1; half < len; half <<= 1) {
        ++level;
        const LL step = wn[level];
        for (int base = 0; base < len; base += half << 1) {
            LL w = 1;
            for (int off = 0; off < half; ++off) {
                LL &lo = a[base + off];
                LL &hi = a[base + off + half];
                const LL even = lo % p;
                const LL odd = w * hi % p;
                lo = (even + odd) % p;
                hi = (even - odd + p) % p;
                w = w * step % p;
            }
        }
    }

    if (on != -1) return;

    // Inverse transform: reverse all entries except the first, then divide
    // every entry by len (multiply by its modular inverse).
    reverse(a + 1, a + len);
    const LL inv_len = fpow(len, p - 2, p);
    for (LL *it = a; it != a + len; ++it) *it = *it * inv_len % p;
}
// Computes the first `deg` coefficients of the inverse of polynomial `a`
// modulo p, writing them to b[0..deg).
//
// Works by doubling: the inverse modulo x^ceil(deg/2) is computed recursively
// into b, then refined with b <- b * (2 - a * b) evaluated pointwise in the
// NTT domain.
//
//   deg : number of coefficients wanted; deg <= 0 is a no-op
//   a   : input coefficients, at least `deg` readable entries
//   b   : output/work buffer; the transforms run over `len` entries, where
//         len is the smallest power of two >= 2 * deg
//
// Preconditions not checked here: len must not exceed MAXN (the size of the
// scratch buffer), and b is presumably expected to be zero-filled on entry,
// since entries past those written by the recursion are fed to the transform
// unchanged -- NOTE(review): confirm against callers.
void calc(int deg,LL *a,LL *b) {
    static LL tmp[MAXN];

    // Without this guard (deg + 1) >> 1 stays at 0 and the recursion never
    // reaches the deg == 1 base case.
    if (deg <= 0) return;

    if (deg == 1) {
        b[0] = fpow(a[0], p - 2);
        return;
    }

    // Inverse modulo x^ceil(deg/2) first; tmp is only used afterwards, so
    // sharing one static scratch buffer across recursion levels is safe.
    calc((deg + 1) >> 1, a, b);

    int len = 1;
    while (len < deg << 1) len <<= 1;

    // tmp = a truncated to deg coefficients, zero-padded to len.
    copy(a, a + deg, tmp);
    fill(tmp + deg, tmp + len, 0);

    ntt(tmp, len, 1);
    ntt(b, len, 1);
    for (int idx = 0; idx < len; idx++) {
        b[idx] = (2 - b[idx] * tmp[idx] % p + p) % p * b[idx] % p;
    }
    ntt(b, len, -1);

    // Drop everything of degree >= deg so the result is the inverse mod x^deg.
    fill(b + deg, b + len, 0);
}
// Entry point: reads n and n coefficients from stdin, computes the inverse
// polynomial modulo x^n, and prints its n coefficients separated by spaces.
// Declared with `mi` because the `int` token is redefined as `long long`.
mi main() {
    // wn[i] = 3^((p-1)/2^i) mod p, the twiddle base ntt() uses at level i.
    for (int i = 0; i < 21; i++) wn[i] = fpow(3, (p - 1) / (1 << i), p);

    n = read();
    // Nothing to compute or print; also keeps calc() away from a
    // non-positive degree.
    if (n <= 0) return 0;

    // calc() transforms over the smallest power of two >= 2n entries of the
    // MAXN-sized global buffers; refuse inputs that would run past them.
    int len = 1;
    if (n <= MAXN) {
        while (len < n << 1) len <<= 1;
    }
    if (n > MAXN || len > MAXN) {
        fputs("input size exceeds the supported maximum\n", stderr);
        return 1;
    }

    for (int i = 0; i < n; i++) a[i] = read();
    calc(n, a, b);
    for (int i = 0; i < n; i++) printf("%lld ", (b[i] + p) % p);
    return 0;
}