Problem
对于一个多项式 a(x) a ( x ) ,求其逆元 b(x) b ( x ) ,即 a(x)∗b(x)≡1(modxn) a ( x ) ∗ b ( x ) ≡ 1 ( mod x n )
Solution
对于单个元素的逆元我们是会求的,比如说一个数 t t 的逆元在膜质数意义下为
但现在要求求一个多项式的逆元,联想到在模数为 x x 时可以快速求得其逆元为,可以考虑从这里开始递推
对于题目可设
a∗b≡1(modx2p)
a
∗
b
≡
1
(
mod
x
2
p
)
a∗c≡1(modxp)
a
∗
c
≡
1
(
mod
x
p
)
即已知 a,c a , c 求 b b
a∗c≡1(modxp)
a
∗
c
≡
1
(
mod
x
p
)
⇒b−c≡0(modxp) ⇒ b − c ≡ 0 ( mod x p )
⇒b2−2bc+c2≡0(modx2p)
⇒
b
2
−
2
b
c
+
c
2
≡
0
(
mod
x
2
p
)
同乘
a
a
考虑到
ab≡1(modx2p)
a
b
≡
1
(
mod
x
2
p
)
⇔b−2c+ac2≡0(modx2p)
⇔
b
−
2
c
+
a
c
2
≡
0
(
mod
x
2
p
)
⇔b≡2c−ac2(modx2p)
⇔
b
≡
2
c
−
a
c
2
(
mod
x
2
p
)
即得到了一个用 a,c a , c 表示 b b 的递推式
借助NTT的模数乘法,时间复杂度为
不过一开始觉得时间复杂度应该是 O(nlog22n) O ( n log 2 2 n )
后来发现复杂度应该是 O(nlogn+n2logn2+…)=O(nlog2n) O ( n log n + n 2 log n 2 + … ) = O ( n log 2 n )
发现自己还是数学思维太弱了
Code
#include<algorithm>
#include<cstdio>
#include<cctype>
using namespace std;
#define rg register
template <typename _Tp> inline _Tp read(_Tp&x){
rg char c11=getchar(),ob=0;x=0;
while(c11^'-'&&!isdigit(c11))c11=getchar();if(c11=='-')c11=getchar(),ob=1;
while(isdigit(c11))x=x*10+c11-'0',c11=getchar();if(ob)x=-x;return x;
}
const int N=2001000,G=3,p=998244353;
int a[N],b[N],c[N],rev[N];
inline int qpow(int A,int B){
int res(1);
while(B){
if(B&1)res=1ll*res*A%p;
A=1ll*A*A%p;B>>=1;
}return res;
}
void ntt(int*a,int n,int f){
for(rg int i=0;i<n;++i)if(i<rev[i])swap(a[i],a[rev[i]]);
for(rg int i=1;i<n;i<<=1){
int wn=qpow(G,(p-1)/(i<<1));
for(rg int j=0;j<n;j+=(i<<1)){
int w(1);
for(rg int k=0;k<i;++k,w=1ll*w*wn%p){
int x=a[j+k],y=1ll*w*a[j+k+i]%p;
a[j+k]=(x+y)%p,a[j+k+i]=(x-y+p)%p;
}
}
}
if(f==1)return ;
int tmp=qpow(n,p-2);reverse(a+1,a+n);
for(rg int i=0;i<n;++i)a[i]=1ll*a[i]*tmp%p;
}
void work(int d,int*a,int*b){
if(d==1){b[0]=qpow(a[0],p-2);return ;}
work((d+1)>>1,a,b);
int l(0),nn(1);
while(nn<(d<<1))nn<<=1,++l;
for(rg int i=1;i<nn;++i)rev[i]=(rev[i>>1]>>1)|((i&1)<<(l-1));
for(rg int i=0;i<d;++i)c[i]=a[i];
for(rg int i=d;i<nn;++i)c[i]=0;
ntt(c,nn,1);ntt(b,nn,1);
for(rg int i=0;i<nn;++i)
b[i]=1ll*(2-1ll*b[i]*c[i]%p+p)%p*b[i]%p;
ntt(b,nn,-1);
for(rg int i=d;i<nn;++i)b[i]=0;
return ;
}
int main(){
int n;read(n);
for(rg int i=0;i<n;++i)read(a[i]);
work(n,a,b);
for(rg int i=0;i<n;++i)printf("%d ",b[i]);
putchar('\n');return 0;
}