介绍
准确来说,是求多项式的逆元。
对于多项式 A ( x ) A(x) A(x),如果存在 A ( x ) B ( x ) ≡ 1 ( m o d x n ) A(x)B(x)\equiv 1 \pmod {x^n} A(x)B(x)≡1(modxn),那么称 B ( x ) B(x) B(x) 为 A ( x ) A(x) A(x) 在模 x n x^n xn 意义下的逆元。
注意,这里的模 x n x^n xn,是指舍弃含有 x n x^n xn 及更高次的项。
比如说有一个多项式: 5 + x + 6 x 2 + 3 x 3 + 2 x 4 5+x+6x^2+3x^3+2x^4 5+x+6x2+3x3+2x4,这个多项式对 x 3 x^3 x3 取模后,就是 5 + x + 6 x 2 5+x+6x^2 5+x+6x2。
正题
假如 A ( x ) B ( x ) ≡ 1 ( m o d x n ) A(x)B(x)\equiv 1 \pmod {x^n} A(x)B(x)≡1(modxn),那么也就是说, A ( x ) A(x) A(x) 和 B ( x ) B(x) B(x) 在模 x n x^n xn 意义下相乘得到的多项式,只有常数项为 1 1 1,其他项的系数都为 0 0 0。
引理 如果满足 A ( x ) B ( x ) ≡ 1 ( m o d x n ) A(x)B(x)\equiv 1 \pmod {x^n} A(x)B(x)≡1(modxn),那么肯定也满足 A ( x ) B ( x ) ≡ 1 ( m o d x m ) ( 1 ≤ m ≤ n ) A(x)B(x)\equiv 1 \pmod {x^m}~~(1\leq m \leq n) A(x)B(x)≡1(modxm) (1≤m≤n),这是下面扯淡的一个大前提,所以稍稍证明一下还是有必要的。
我们设 C ( x ) = A ( x ) B ( x ) ( m o d x n ) C(x)=A(x)B(x) \pmod{x^n} C(x)=A(x)B(x)(modxn),那么有: C i = ∑ j = 0 i A j B i − j ( m o d x n ) C_i=\sum_{j=0}^i A_jB_{i-j} \pmod {x^n} Ci=∑j=0iAjBi−j(modxn),也就是说, C i C_i Ci 只跟 A ( x ) , B ( x ) A(x),B(x) A(x),B(x) 的第 j ( j ≤ i ) j~~(j\leq i) j (j≤i) 项有关,那么在舍弃掉第 i i i 项以上的项时,并不会对 C i C_i Ci 产生影响。
由此可知,当舍弃掉 x m x^m xm 及以上的项时,并不会影响 C m − 1 C_{m-1} Cm−1 及以下的项。
于是 A ( x ) B ( x ) ≡ 1 ( m o d x m ) ( 1 ≤ m ≤ n ) A(x)B(x) \equiv 1 \pmod{x^m}~~(1\leq m \leq n) A(x)B(x)≡1(modxm) (1≤m≤n) 成立。
于是我们就可以愉快地推柿子了:
设 B ( x ) B(x) B(x) 为 A ( x ) A(x) A(x) 在模 x n x^n xn 意义下的逆元, G ( x ) G(x) G(x) 为 A ( x ) A(x) A(x) 的在模 x ⌈ n 2 ⌉ x^{\lceil \frac n 2 \rceil} x⌈2n⌉ 意义下的逆元(事实上由上面的引理可知, G G G 也就是 B B B 模 x ⌈ n 2 ⌉ x^{\lceil \frac n 2 \rceil} x⌈2n⌉ 后得到的多项式)。
那么有
G
(
x
)
≡
B
(
x
)
(
m
o
d
x
⌈
n
2
⌉
)
B
(
x
)
−
G
(
x
)
≡
0
(
m
o
d
x
⌈
n
2
⌉
)
(
B
(
x
)
−
G
(
x
)
)
2
≡
0
(
m
o
d
x
n
)
B
(
x
)
2
−
2
B
(
x
)
G
(
x
)
+
G
(
x
)
2
≡
0
(
m
o
d
x
n
)
G(x)\equiv B(x) \pmod {x^{\lceil \frac n 2 \rceil}}\\ B(x)-G(x)\equiv 0 \pmod {x^{\lceil \frac n 2 \rceil}}\\ (B(x)-G(x))^2\equiv 0 \pmod {x^n}\\ B(x)^2-2B(x)G(x)+G(x)^2\equiv 0 \pmod {x^n}\\
G(x)≡B(x)(modx⌈2n⌉)B(x)−G(x)≡0(modx⌈2n⌉)(B(x)−G(x))2≡0(modxn)B(x)2−2B(x)G(x)+G(x)2≡0(modxn)
让整个柿子乘上
A
(
x
)
A(x)
A(x),那么有:
A
(
x
)
B
(
x
)
B
(
x
)
−
2
A
(
x
)
B
(
x
)
G
(
x
)
+
A
(
x
)
G
(
x
)
2
≡
0
(
m
o
d
x
n
)
A(x)B(x)B(x)-2A(x)B(x)G(x)+A(x)G(x)^2\equiv 0 \pmod {x^n}\\
A(x)B(x)B(x)−2A(x)B(x)G(x)+A(x)G(x)2≡0(modxn)
因为
A
(
x
)
B
(
x
)
≡
1
(
m
o
d
x
n
)
A(x)B(x)\equiv 1 \pmod {x^n}
A(x)B(x)≡1(modxn),所以有
B
(
x
)
−
2
G
(
x
)
+
A
(
x
)
G
(
x
)
2
≡
0
(
m
o
d
x
n
)
B
(
x
)
≡
2
G
(
x
)
−
A
(
x
)
G
(
x
)
2
(
m
o
d
x
n
)
B
(
x
)
≡
G
(
x
)
(
2
−
A
(
x
)
G
(
x
)
)
(
m
o
d
x
n
)
B(x)-2G(x)+A(x)G(x)^2\equiv 0 \pmod {x^n}\\ B(x)\equiv 2G(x)-A(x)G(x)^2 \pmod {x^n}\\ B(x)\equiv G(x)(2-A(x)G(x)) \pmod {x^n}\\
B(x)−2G(x)+A(x)G(x)2≡0(modxn)B(x)≡2G(x)−A(x)G(x)2(modxn)B(x)≡G(x)(2−A(x)G(x))(modxn)
这说明,我们可以通过 G ( x ) G(x) G(x) 求得 B ( x ) B(x) B(x)。
然后像倍增一样搞,从小往大推即可。
对于里面 A ( x ) G ( x ) A(x)G(x) A(x)G(x) 这个部分,因为答案要求对 998244353 998244353 998244353 取模,所以用 N T T NTT NTT 搞一搞即可。
代码如下:
#include <cstdio>
#include <cmath>
#include <algorithm>
using namespace std;
#define maxn 300010
#define ll long long
#define mod 998244353
int n;
ll ksm(ll x,ll y)
{
ll re=1,tot=x;
while(y)
{
if(y&1)re=re*tot%mod;
tot=tot*tot%mod;
y>>=1;
}
return re;
}
#define inv(x) ksm(x,mod-2)
int up,l,r[maxn];
void work(int len)
{
for(l=0,up=1;up<=len;up<<=1,l++);
for(int i=1;i<up;i++)
r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
}
ll a[maxn],b[maxn];
const int G=3,invG=inv(G);
void ntt(ll *f,int len,int type)
{
for(int i=1;i<len;i++)
if(i<r[i])swap(f[i],f[r[i]]);
for(int mid=1;mid<len;mid<<=1)
{
ll wn=ksm((type==1?G:invG),(mod-1)/mid>>1);
for(int block=mid<<1,j=0;j<len;j+=block)
{
ll w=1;
for(int i=j;i<j+mid;i++,w=w*wn%mod)
{
int x=f[i],y=f[i+mid]*w%mod;
f[i]=(x+y)%mod;f[i+mid]=(x-y+mod)%mod;
}
}
}
}
void solve(int len,ll *f,ll *g)
{
if(len==1){g[0]=inv(f[0]);return;}
solve((len+1)>>1,f,g);
work(len+n);
for(int i=0;i<up;i++)
a[i]=f[i],b[i]=(i<(len+1)>>1?g[i]:0);
ntt(a,up,1);ntt(b,up,1);
for(int i=0;i<up;i++)
a[i]=b[i]*((2-a[i]*b[i]%mod+mod)%mod)%mod;
ntt(a,up,-1);
int invup=inv(up);
for(int i=0;i<len;i++)
g[i]=a[i]*invup%mod;
}
ll f[maxn],g[maxn];
int main()
{
scanf("%d",&n);
for(int i=0;i<n;i++)
scanf("%lld",&f[i]);
solve(n,f,g);
for(int i=0;i<n;i++)
printf("%lld ",g[i]);
}