例题:洛谷P4238
什么是多项式的逆呢?嗯,就是如例题题面描述的那样,求满足
F(x)G(x)≡1(modxn)
F
(
x
)
G
(
x
)
≡
1
(
mod
x
n
)
的G。
如果多项式F只有一项,那么显然
G0
G
0
就是
F0
F
0
的逆元。
若有n项,递归求解。
假如我们已知
F(x)H(x)≡1(modx⌈n2⌉)
F
(
x
)
H
(
x
)
≡
1
(
mod
x
⌈
n
2
⌉
)
又显然
F(x)G(x)≡1(modx⌈n2⌉)
F
(
x
)
G
(
x
)
≡
1
(
mod
x
⌈
n
2
⌉
)
那么
F(x)(G(x)−H(x))≡0(modx⌈n2⌉)
F
(
x
)
(
G
(
x
)
−
H
(
x
)
)
≡
0
(
mod
x
⌈
n
2
⌉
)
即
G(x)−H(x)≡0(modx⌈n2⌉)
G
(
x
)
−
H
(
x
)
≡
0
(
mod
x
⌈
n
2
⌉
)
两边同时平方。由于
G(x)−H(x)
G
(
x
)
−
H
(
x
)
在模
x⌈n2⌉
x
⌈
n
2
⌉
为0,则其0次项到
⌈n2⌉−1
⌈
n
2
⌉
−
1
次项都为0。平方后的多项式记为P,则
Pi=∑ij=0(G(x)−H(x))j(G(x)−H(x))i−j
P
i
=
∑
j
=
0
i
(
G
(
x
)
−
H
(
x
)
)
j
(
G
(
x
)
−
H
(
x
)
)
i
−
j
,显然
(G(x)−H(x))j
(
G
(
x
)
−
H
(
x
)
)
j
和
(G(x)−H(x))i−j
(
G
(
x
)
−
H
(
x
)
)
i
−
j
至少有一项的次数小于
⌈n2⌉
⌈
n
2
⌉
,为0,所以:
G(x)2+H(x)2−2G(x)H(x)≡0(modxn)
G
(
x
)
2
+
H
(
x
)
2
−
2
G
(
x
)
H
(
x
)
≡
0
(
mod
x
n
)
两边同时乘F(x),再由
F(x)G(x)≡1(modxn)
F
(
x
)
G
(
x
)
≡
1
(
mod
x
n
)
可得:
G(x)≡2H(x)−F(x)H(x)2(modxn)
G
(
x
)
≡
2
H
(
x
)
−
F
(
x
)
H
(
x
)
2
(
mod
x
n
)
用NTT来做多项式乘法即可解决本题。
时间复杂度是
O(nlogn)
O
(
n
l
o
g
n
)
的
#include<bits/stdc++.h>
using namespace std;
int read() {
int q=0;char ch=' ';
while(ch<'0'||ch>'9') ch=getchar();
while(ch>='0'&&ch<='9') q=q*10+ch-'0',ch=getchar();
return q;
}
#define RI register int
const int mod=998244353,G=3,N=2100000;
int n;
int a[N],b[N],c[N],rev[N];
int ksm(int x,int y) {
int re=1;
for(;y;y>>=1,x=1LL*x*x%mod) if(y&1) re=1LL*re*x%mod;
return re;
}
void NTT(int *a,int n,int x) {
for(RI i=0;i<n;++i) if(i<rev[i]) swap(a[i],a[rev[i]]);
for(RI i=1;i<n;i<<=1) {
RI gn=ksm(G,(mod-1)/(i<<1));
for(RI j=0;j<n;j+=(i<<1)) {
RI t1,t2,g=1;
for(RI k=0;k<i;++k,g=1LL*g*gn%mod) {
t1=a[j+k],t2=1LL*g*a[j+k+i]%mod;
a[j+k]=(t1+t2)%mod,a[j+k+i]=(t1-t2+mod)%mod;
}
}
}
if(x==1) return;
int ny=ksm(n,mod-2); reverse(a+1,a+n);
for(RI i=0;i<n;++i) a[i]=1LL*a[i]*ny%mod;
}
void work(int deg,int *a,int *b) {
if(deg==1) {b[0]=ksm(a[0],mod-2);return;}
work((deg+1)>>1,a,b);
RI len=0,orz=1;
while(orz<(deg<<1)) orz<<=1,++len;
for(RI i=1;i<orz;++i) rev[i]=(rev[i>>1]>>1)|((i&1)<<(len-1));
for(RI i=0;i<deg;++i) c[i]=a[i];
for(RI i=deg;i<orz;++i) c[i]=0;
NTT(c,orz,1),NTT(b,orz,1);
for(RI i=0;i<orz;++i)
b[i]=1LL*(2-1LL*c[i]*b[i]%mod+mod)%mod*b[i]%mod;
NTT(b,orz,-1);
for(RI i=deg;i<orz;++i) b[i]=0;
}
int main()
{
n=read();
for(RI i=0;i<n;++i) a[i]=read();
work(n,a,b);
for(RI i=0;i<n;++i) printf("%d ",b[i]);
return 0;
}