传送门
多项式求逆模板题。
简单讲讲?
多项式求逆
- 定义:
对于一个多项式 $A(x)$,如果存在一个多项式 $B(x)$,满足 $\deg B(x) < n$ 且 $A(x)B(x)\equiv 1 \pmod{x^n}$,那么我们称 $B(x)$ 为 $A(x)$ 在模 $x^n$ 意义下的逆元,简单记作 $A^{-1}(x)$。
- 求法:
$n=1$?那不就是 $A(x)$ 的常数项 $a_0$ 在模 $p$ 意义下的逆元么(要求 $a_0 \neq 0$,否则逆元不存在)。
$n>1$?我们令 $B(x)=A^{-1}(x)$,
那么有 $A(x)B(x)\equiv 1 \pmod{x^n}$。
然后可以用类似倍增的方法求。
假设我们已经知道 $C(x)$ 满足 $A(x)C(x)\equiv 1 \pmod{x^{\lceil n/2\rceil}}$(这里的 $\frac n2$ 都是向上取整,即 $\lceil n/2\rceil$)。
显然 $A(x)B(x)\equiv 1 \pmod{x^{\lceil n/2\rceil}}$ 也是成立的。
我们将两式相减:
$A(x)\big(B(x)-C(x)\big)\equiv 0 \pmod{x^{\lceil n/2\rceil}}$
两边同乘 $C(x)$,并利用 $A(x)C(x)\equiv 1 \pmod{x^{\lceil n/2\rceil}}$,所以 $B(x)-C(x)\equiv 0 \pmod{x^{\lceil n/2\rceil}}$。
然后将两边平方:
$B^2(x)-2B(x)C(x)+C^2(x)\equiv 0 \pmod{x^{\lceil n/2\rceil}}$
$\Rightarrow B^2(x)-2B(x)C(x)+C^2(x)\equiv 0 \pmod{x^n}$
这一步很关键:$B(x)-C(x)$ 的最低非零项次数至少为 $\lceil n/2\rceil$,所以 $\big(B(x)-C(x)\big)^2$ 的最低非零项次数至少为 $2\lceil n/2\rceil \ge n$,因此模数可以从 $x^{\lceil n/2\rceil}$ 提升到 $x^n$。
然后两边同时乘上 $A(x)$,并利用 $A(x)B(x)\equiv 1 \pmod{x^n}$:
$\Rightarrow B(x)-2C(x)+A(x)C^2(x)\equiv 0 \pmod{x^n}$
于是 $B(x)\equiv 2C(x)-A(x)C^2(x) \pmod{x^n}$
乘法可以用 FFT/NTT 加速。每次递归时问题规模减半,即 $T(n)=T(\lceil n/2\rceil)+O(n\log n)$,所以总复杂度仍然是 $O(n\log n)$。
代码:
#include<bits/stdc++.h>
#define ri register int
using namespace std;
// Reads the next non-negative decimal integer from stdin, skipping any
// non-digit characters (including a '-' sign) in front of it.
// Returns 0 if EOF is reached before any digit, instead of spinning forever.
// Assumes the value fits in int (true for coefficients below mod).
inline int read(){
    int ans=0;
    int ch=std::getchar();  // int, not char: must be able to represent EOF
    while(ch!=EOF&&!std::isdigit(ch))ch=std::getchar();
    while(std::isdigit(ch)){  // isdigit(EOF) is false, so this stops at EOF
        ans=ans*10+(ch-'0');
        ch=std::getchar();
    }
    return ans;
}
// Modular-arithmetic core over the NTT-friendly prime 998244353.
using ll=long long;
constexpr int mod=998244353;
int n;  // degree of the input polynomial (assigned in main)

// a^p mod `mod` by binary exponentiation. p must be non-negative.
// The base is first reduced into [0, mod) so negative or oversized bases
// still produce a canonical residue.
constexpr int ksm(int a,int p){
    a%=mod;
    if(a<0)a+=mod;
    int ret=1;
    for(;p;p>>=1,a=static_cast<int>(static_cast<ll>(a)*a%mod)){
        if(p&1)ret=static_cast<int>(static_cast<ll>(ret)*a%mod);
    }
    return ret;
}
// Modular add / subtract / multiply; operands must already lie in [0, mod).
constexpr int add(const int&a,const int&b){return a+b>=mod?a+b-mod:a+b;}
constexpr int dec(const int&a,const int&b){return a>=b?a-b:a-b+mod;}
constexpr int mul(const int&a,const int&b){return static_cast<int>(static_cast<ll>(a)*b%mod);}
// NTT working state shared by init(), ntt() and polynomial multiplication:
//   lim - transform length (a power of two), tim - log2(lim),
//   pos - bit-reversal permutation of [0, lim), A/B - scratch buffers.
int lim,tim;
std::vector<int>pos,A,B;

// Picks the smallest power of two lim with lim > n (enough room for the
// n+1 coefficients of a degree-n product) and rebuilds the bit-reversal
// table pos for that length.
inline void init(const int&n){
    lim=1,tim=0;
    while(lim<=n)lim<<=1,++tim;
    pos.assign(lim,0);
    // Start at i=1: pos[0] is always 0, and when lim==1 (tim==0) the shift
    // amount tim-1 would be -1, which is undefined behavior.
    for(int i=1;i<lim;++i)pos[i]=(pos[i>>1]>>1)|((i&1)<<(tim-1));
}
inline void ntt(vector<int>&a,int type){
for(ri i=0;i<lim;++i)if(i<pos[i])swap(a[i],a[pos[i]]);
for(ri mult=(mod-1)/2,mid=1,wn,typ=type==1?3:(mod+1)/3;mid<lim;mid<<=1,mult>>=1){
wn=ksm(typ,mult);
for(ri j=0,len=mid<<1,w;j<lim;j+=len){
w=1;
for(ri k=0,a0,a1;k<mid;w=mul(w,wn),++k){
a0=a[j+k],a1=mul(w,a[j+k+mid]);
a[j+k]=add(a0,a1),a[j+k+mid]=dec(a0,a1);
}
}
}
if(type==-1)for(ri i=0,inv=ksm(lim,mod-2);i<lim;++i)a[i]=mul(a[i],inv);
}
// Dense polynomial over Z_mod: a[i] is the coefficient of x^i and
// deg() == a.size()-1 (trailing zero coefficients are kept, not trimmed).
struct poly{
    std::vector<int>a;

    // Degree-n polynomial whose x^n coefficient is x and all others are 0;
    // poly(0,c) is the constant c. Requires n >= 0.
    poly(const int&n,int x=0){a.resize(n+1),a[n]=x;}

    inline int&operator[](const int&i){return a[i];}
    inline const int&operator[](const int&i)const{return a[i];}

    // Copy resized to exactly x coefficients: truncation mod x^x, or zero
    // padding. const so it also works on const polynomials and temporaries.
    inline poly extend(const int&x)const{
        poly ret=*this;
        ret.a.resize(x);
        return ret;
    }

    inline int deg()const{return static_cast<int>(a.size())-1;}

    friend inline poly operator+(const poly&a,const poly&b){
        poly c(std::max(a.deg(),b.deg()));
        for(int i=0;i<=a.deg();++i)c[i]=add(c[i],a[i]);
        for(int i=0;i<=b.deg();++i)c[i]=add(c[i],b[i]);
        return c;
    }

    friend inline poly operator-(const poly&a,const poly&b){
        poly c(std::max(a.deg(),b.deg()));
        for(int i=0;i<=a.deg();++i)c[i]=add(c[i],a[i]);
        for(int i=0;i<=b.deg();++i)c[i]=dec(c[i],b[i]);
        return c;
    }

    // Multiplies every coefficient by the scalar b (b must lie in [0, mod)).
    friend inline poly operator*(const poly&a,const int&b){
        poly c=a;
        for(int i=0;i<=c.deg();++i)c[i]=mul(b,c[i]);
        return c;
    }

    // Polynomial product via NTT. Works through the global lim/pos and
    // scratch buffers A/B, so it is not reentrant. The result has lim
    // coefficients (lim > deg a + deg b); the tail beyond the true degree is 0.
    friend inline poly operator*(const poly&a,const poly&b){
        const int n=a.deg(),m=b.deg();
        init(n+m);
        A.assign(lim,0);
        B.assign(lim,0);
        for(int i=0;i<=n;++i)A[i]=a[i];
        for(int i=0;i<=m;++i)B[i]=b[i];
        ntt(A,1);
        ntt(B,1);
        for(int i=0;i<lim;++i)A[i]=mul(A[i],B[i]);
        ntt(A,-1);
        poly ret(lim-1);
        ret.a=A;
        return ret;
    }

    // Inverse of A modulo x^k by Newton doubling: with C the inverse modulo
    // x^ceil(k/2), the inverse modulo x^k is B = 2C - A*C^2 (mod x^k).
    // Returns exactly k coefficients (one for k <= 1). A must be non-empty and
    // A[0] != 0; otherwise no inverse exists and the result is meaningless.
    // static: it never touches *this (obj.poly_inv(...) calls remain valid).
    // Note: the parameter A shadows the global scratch buffer A.
    static inline poly poly_inv(poly A,const int k){
        if(k<=1)return poly(0,ksm(A[0],mod-2));  // k<=1 also stops k==0 recursing forever
        const int half=(k+1)>>1;
        const poly f0=poly_inv(A.extend(half),half);
        return (f0*2-(A*((f0*f0).extend(k))).extend(k)).extend(k);
    }
};
int main(){
freopen("lx.in","r",stdin);
n=read()-1;
poly a(n);
for(ri i=0;i<=n;++i)a[i]=read();
init(n*2),a=a.extend(lim),a=a.poly_inv(a,lim);
for(ri i=0;i<=n;++i)cout<<a[i]<<' ';
return 0;
}