多项式求逆&多项式开方

多项式求逆


运用基于倍增的多项式求逆可以在 O(nlogn) O ( n log ⁡ n ) 时间内,对于一个 n n 次多项式A(x)求出 B(x) B ( x ) 使得 B(x)A(x)=1 B ( x ) ⋅ A ( x ) = 1
假设当前我们已经求出了在 modxn mod x n 意义下 A(x) A ( x ) 的逆 B(x) B ′ ( x ) ,考虑怎么求 modx2n mod x 2 n 意义下的逆 B(x) B ( x )

A(x)B(x)1(modxn) , A(x)B(x)1(modxn) A ( x ) ⋅ B ′ ( x ) ≡ 1 ( mod x n )   ,   A ( x ) ⋅ B ( x ) ≡ 1 ( mod x n )

两式相减得
A(x)(B(x)B(x))0(modxn)B(x)B(x)0(modxn) A ( x ) ⋅ ( B ( x ) − B ′ ( x ) ) ≡ 0 ( mod x n ) → B ( x ) − B ′ ( x ) ≡ 0 ( mod x n )

两边平方(平方后模数也可以 2 ∗ 2
B2(x)2B(x)B(x)+B2(x)0(modx2n) B 2 ( x ) − 2 B ( x ) B ′ ( x ) + B ′ 2 ( x ) ≡ 0 ( mod x 2 n )

在两边同乘 A(x) A ( x )
B(x)2B(x)+A(x)B2(x)0(modx2n)B(x)2B(x)A(x)B2(x)(modx2n) B ( x ) − 2 B ′ ( x ) + A ( x ) B ′ 2 ( x ) ≡ 0 ( mod x 2 n ) → B ( x ) ≡ 2 B ′ ( x ) − A ( x ) B ′ 2 ( x ) ( mod x 2 n )

于是我们解出了 B(x) B ( x ) ,通过FFT可以做到 O(nlogn) O ( n log ⁡ n ) ,总复杂度 T(n)=T(n2)+O(nlogn)=O(nlogn) T ( n ) = T ( n 2 ) + O ( n log ⁡ n ) = O ( n log ⁡ n )

代码:

void get_inv(int *a,int *b,int n)
{
    if(n==1){b[0]=ksm(a[0],mod-2);return ;}
    get_inv(a,b,n>>1);n<<=1;
    for(int i=0;i<(n>>1);i++)
        t_I[i]=a[i],t_I[i+(n>>1)]=0;
    ntt(t_I,n,1);ntt(b,n,1);
    for(int i=0;i<n;i++)
        b[i]=((b[i]<<1)-(ll)t_I[i]*b[i]%mod*b[i]%mod+mod)%mod;
    ntt(b,n,-1);
    for(int i=(n>>1);i<n;i++)
        b[i]=0;     
}

多项式开方


同样利用倍增,可以在 O(nlogn) O ( n log ⁡ n ) 时间内,对于一个 n n 次多项式A(x),求出 B(x) B ( x ) 使得 B2(x)=A(x) B 2 ( x ) = A ( x )

B2(x)=A(x)(modxn),B2(x)=A(x)(modx2n) B ′ 2 ( x ) = A ( x ) ( mod x n ) , B 2 ( x ) = A ( x ) ( mod x 2 n )

同理可得
B2(x)2B(x)B(x)+B2(x)0(modx2n) B 2 ( x ) − 2 B ( x ) B ′ ( x ) + B ′ 2 ( x ) ≡ 0 ( mod x 2 n )

B2(x) B 2 ( x ) 就是 A(x) A ( x ) ,于是可以解出
B(x)=A(x)+B2(x)2B(x) B ( x ) = A ( x ) + B ′ 2 ( x ) 2 B ′ ( x )

需要求一下 B(x) B ′ ( x ) 的逆。
代码:

void get_sqrt(int *a,int *b,int n)
{
    if(n==1){b[0]=1;return;}
    get_sqrt(a,b,n>>1);n<<=1;
    memset(t_S,0,sizeof(int)*n);
    get_inv(b,t_S,n>>1);   
    for(int i=0;i<(n>>1);i++) 
        t_I[i]=a[i],t_I[i+(n>>1)]=0;
    ntt(t_I,n,1);ntt(t_S,n,1);ntt(b,n,1);
    for(int i=0;i<n;i++)
        b[i]=((ll)b[i]*b[i]+t_I[i])%mod*t_S[i]%mod*inv_2%mod;
    ntt(b,n,-1);
    for(int i=(n>>1);i<n;i++)
        b[i]=0; 
}

还有几个要注意的地方:
1. 从 n2 n 2 次推到 n n 次的时候,fft要开到2n,因为有三个东西相乘。
2. 最后要把 n ≥ n 次的项赋为 0 0 ,就是要记得(modxn)

模板题:CF438E
代码:

#include<iostream>
#include<cstdio>
#include<cstring>
#define N 262150
#define ll long long
using namespace std;
const int mod=998244353;
const int g=3;
int n,m,r[N],inv_2;
int h[N],f[N],t_I[N],t_S[N];
int ksm(ll a,int b){ll r=1;for(b=(b+mod-1)%(mod-1);b;b>>=1){if(b&1)r=r*a%mod;a=a*a%mod;}return r;}
void ntt(int *a,int n,int dft)
{
    for(int i=0;i<n;i++)
        r[i]=(r[i>>1]>>1)|((i&1)*(n>>1));
    for(int i=0;i<n;i++)
        if(i<r[i]) swap(a[i],a[r[i]]);
    for(int i=1;i<n;i<<=1)
    {
        int wn=ksm(g,(mod-1)/(i<<1)*dft);
        for(int j=0;j<n;j+=(i<<1))
        {
            int wk=1;
            for(int k=j;k<j+i;k++,wk=(ll)wk*wn%mod)
            {
                ll x=a[k],y=(ll)wk*a[k+i]%mod;
                a[k]=(x+y)%mod;a[k+i]=(x-y+mod)%mod;
            }
        }
    }
    if(dft==-1) for(int i=0,inv=ksm(n,mod-2);i<n;i++) a[i]=(ll)a[i]*inv%mod;
}
void get_inv(int *a,int *b,int n)
{
    if(n==1){b[0]=ksm(a[0],mod-2);return ;}
    get_inv(a,b,n>>1);n<<=1;
    for(int i=0;i<(n>>1);i++)
        t_I[i]=a[i],t_I[i+(n>>1)]=0;
    ntt(t_I,n,1);ntt(b,n,1);
    for(int i=0;i<n;i++)
        b[i]=((b[i]<<1)-(ll)t_I[i]*b[i]%mod*b[i]%mod+mod)%mod;
    ntt(b,n,-1);
    for(int i=(n>>1);i<n;i++)
        b[i]=0;       
}
void get_sqrt(int *a,int *b,int n)
{
    if(n==1){b[0]=1;return;}
    get_sqrt(a,b,n>>1);n<<=1;
    memset(t_S,0,sizeof(int)*n);
    get_inv(b,t_S,n>>1);  
    for(int i=0;i<(n>>1);i++) 
        t_I[i]=a[i],t_I[i+(n>>1)]=0;
    ntt(t_I,n,1);ntt(t_S,n,1);ntt(b,n,1);
    for(int i=0;i<n;i++)
        b[i]=((ll)b[i]*b[i]+t_I[i])%mod*t_S[i]%mod*inv_2%mod;
    ntt(b,n,-1);
    for(int i=(n>>1);i<n;i++)
        b[i]=0;
}
int main()
{
    scanf("%d%d",&n,&m);
    inv_2=ksm(2,mod-2);
    h[0]=1;
    for(int i=1;i<=n;i++)
    {
        int x;
        scanf("%d",&x);
        h[x]=(h[x]-4+mod)%mod;
    }
    for(n=1;n<=m;n<<=1);
    get_sqrt(h,f,n);
    memcpy(h,f,sizeof(h));
    memset(f,0,sizeof(f));  
    h[0]=(h[0]+1)%mod;
    get_inv(h,f,n);
    for(int i=1;i<=m;i++)
        printf("%d\n",(f[i]<<1)%mod);
    return 0;
}
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值