多项式求逆
运用基于倍增的多项式求逆可以在
O(nlogn)
O
(
n
log
n
)
时间内,对于一个
n
n
次多项式求出
B(x)
B
(
x
)
使得
B(x)⋅A(x)=1
B
(
x
)
⋅
A
(
x
)
=
1
。
假设当前我们已经求出了在
modxn
mod
x
n
意义下
A(x)
A
(
x
)
的逆
B′(x)
B
′
(
x
)
,考虑怎么求
modx2n
mod
x
2
n
意义下的逆
B(x)
B
(
x
)
。
两式相减得
两边平方(平方后模数也可以 ∗2 ∗ 2 )
在两边同乘 A(x) A ( x )
于是我们解出了 B(x) B ( x ) ,通过FFT可以做到 O(nlogn) O ( n log n ) ,总复杂度 T(n)=T(n2)+O(nlogn)=O(nlogn) T ( n ) = T ( n 2 ) + O ( n log n ) = O ( n log n ) 。
代码:
void get_inv(int *a,int *b,int n)
{
if(n==1){b[0]=ksm(a[0],mod-2);return ;}
get_inv(a,b,n>>1);n<<=1;
for(int i=0;i<(n>>1);i++)
t_I[i]=a[i],t_I[i+(n>>1)]=0;
ntt(t_I,n,1);ntt(b,n,1);
for(int i=0;i<n;i++)
b[i]=((b[i]<<1)-(ll)t_I[i]*b[i]%mod*b[i]%mod+mod)%mod;
ntt(b,n,-1);
for(int i=(n>>1);i<n;i++)
b[i]=0;
}
多项式开方
同样利用倍增,可以在
O(nlogn)
O
(
n
log
n
)
时间内,对于一个
n
n
次多项式,求出
B(x)
B
(
x
)
使得
B2(x)=A(x)
B
2
(
x
)
=
A
(
x
)
。
同理可得
B2(x) B 2 ( x ) 就是 A(x) A ( x ) ,于是可以解出
需要求一下 B′(x) B ′ ( x ) 的逆。
代码:
void get_sqrt(int *a,int *b,int n)
{
if(n==1){b[0]=1;return;}
get_sqrt(a,b,n>>1);n<<=1;
memset(t_S,0,sizeof(int)*n);
get_inv(b,t_S,n>>1);
for(int i=0;i<(n>>1);i++)
t_I[i]=a[i],t_I[i+(n>>1)]=0;
ntt(t_I,n,1);ntt(t_S,n,1);ntt(b,n,1);
for(int i=0;i<n;i++)
b[i]=((ll)b[i]*b[i]+t_I[i])%mod*t_S[i]%mod*inv_2%mod;
ntt(b,n,-1);
for(int i=(n>>1);i<n;i++)
b[i]=0;
}
还有几个要注意的地方:
1. 从
n2
n
2
次推到
n
n
次的时候,fft要开到,因为有三个东西相乘。
2. 最后要把
≥n
≥
n
次的项赋为
0
0
,就是要记得。
模板题:CF438E
代码:
#include<iostream>
#include<cstdio>
#include<cstring>
#define N 262150
#define ll long long
using namespace std;
const int mod=998244353;
const int g=3;
int n,m,r[N],inv_2;
int h[N],f[N],t_I[N],t_S[N];
int ksm(ll a,int b){ll r=1;for(b=(b+mod-1)%(mod-1);b;b>>=1){if(b&1)r=r*a%mod;a=a*a%mod;}return r;}
void ntt(int *a,int n,int dft)
{
for(int i=0;i<n;i++)
r[i]=(r[i>>1]>>1)|((i&1)*(n>>1));
for(int i=0;i<n;i++)
if(i<r[i]) swap(a[i],a[r[i]]);
for(int i=1;i<n;i<<=1)
{
int wn=ksm(g,(mod-1)/(i<<1)*dft);
for(int j=0;j<n;j+=(i<<1))
{
int wk=1;
for(int k=j;k<j+i;k++,wk=(ll)wk*wn%mod)
{
ll x=a[k],y=(ll)wk*a[k+i]%mod;
a[k]=(x+y)%mod;a[k+i]=(x-y+mod)%mod;
}
}
}
if(dft==-1) for(int i=0,inv=ksm(n,mod-2);i<n;i++) a[i]=(ll)a[i]*inv%mod;
}
void get_inv(int *a,int *b,int n)
{
if(n==1){b[0]=ksm(a[0],mod-2);return ;}
get_inv(a,b,n>>1);n<<=1;
for(int i=0;i<(n>>1);i++)
t_I[i]=a[i],t_I[i+(n>>1)]=0;
ntt(t_I,n,1);ntt(b,n,1);
for(int i=0;i<n;i++)
b[i]=((b[i]<<1)-(ll)t_I[i]*b[i]%mod*b[i]%mod+mod)%mod;
ntt(b,n,-1);
for(int i=(n>>1);i<n;i++)
b[i]=0;
}
void get_sqrt(int *a,int *b,int n)
{
if(n==1){b[0]=1;return;}
get_sqrt(a,b,n>>1);n<<=1;
memset(t_S,0,sizeof(int)*n);
get_inv(b,t_S,n>>1);
for(int i=0;i<(n>>1);i++)
t_I[i]=a[i],t_I[i+(n>>1)]=0;
ntt(t_I,n,1);ntt(t_S,n,1);ntt(b,n,1);
for(int i=0;i<n;i++)
b[i]=((ll)b[i]*b[i]+t_I[i])%mod*t_S[i]%mod*inv_2%mod;
ntt(b,n,-1);
for(int i=(n>>1);i<n;i++)
b[i]=0;
}
int main()
{
scanf("%d%d",&n,&m);
inv_2=ksm(2,mod-2);
h[0]=1;
for(int i=1;i<=n;i++)
{
int x;
scanf("%d",&x);
h[x]=(h[x]-4+mod)%mod;
}
for(n=1;n<=m;n<<=1);
get_sqrt(h,f,n);
memcpy(h,f,sizeof(h));
memset(f,0,sizeof(f));
h[0]=(h[0]+1)%mod;
get_inv(h,f,n);
for(int i=1;i<=m;i++)
printf("%d\n",(f[i]<<1)%mod);
return 0;
}