P5394 【模板】下降幂多项式乘法

题目不好复制

要吸吸氧才能过。
这种写法比第二种写法少一次ntt。

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<string>
#include<queue>
#define ll long long
#define llu unsigned ll
#define int ll
using namespace std;
const int maxn=8e5+1000;
const int p=998244353;
const int g=3;
int fi[maxn];
int a[maxn],b[maxn],c[maxn];
int fac[maxn],inv[maxn];

int mypow(int a,int b)
{
    if(b<0) return mypow(mypow(a,p-2),-b);
    int ans=1;
    while(b)
    {
        if(b&1) ans=ans*a%p;
        a=a*a%p;
        b>>=1;
    }
    return ans%p;
}

void init(int n)
{
    fac[0]=1;
    for(int i=1;i<=n;i++)
        fac[i]=fac[i-1]*i%p;
    inv[n]=mypow(fac[n],p-2);
    for(int i=n-1;i>=0;i--)
        inv[i]=inv[i+1]*(i+1)%p;
}

void ntt(int *x,int len,int f)
{
    for(int i=0;i<len;i++)
        if(i<fi[i]) swap(x[i],x[fi[i]]);

    for(int i=1;i<len;i<<=1)
    {
        int r=i<<1;
        int wn=mypow(g,f*(p-1)/r);
        for(int j=0;j<len;j+=r)
        {
            int w=1;
            for(int k=0;k<i;k++)
            {
                int xx=x[j+k],yy=w*x[j+i+k]%p;
                x[j+k]=(xx+yy)%p;
                x[j+i+k]=((xx-yy)%p+p)%p;
                w=w*wn%p;
            }
        }
    }
    if(f==-1)
    {
        int invn=mypow(len,p-2);
        for(int i=0;i<len;i++)
            x[i]=x[i]*invn%p;
    }
}


signed main(void)
{
    int n,m;
    scanf("%lld%lld",&n,&m);
    for(int i=0;i<=n;i++)
        scanf("%lld",&a[i]);
    for(int i=0;i<=m;i++)
        scanf("%lld",&b[i]);

    int now=n+m;
    n=now,m=now;
    int len=1,cnt=0;
    while(len<=n+m) len<<=1,cnt++;
    for(int i=0;i<len;i++)
        fi[i]=((fi[i>>1]>>1)|((i&1)<<(cnt-1)));

    init(now);
    for(int i=0;i<=now;i++)
        c[i]=inv[i];
    ntt(a,len,1);
    ntt(b,len,1);
    ntt(c,len,1);
    for(int i=0;i<len;i++)
        a[i]=a[i]*c[i]%p,b[i]=b[i]*c[i]%p;
    ntt(a,len,-1);
    ntt(b,len,-1);
    for(int i=0;i<=now;i++)
        a[i]=a[i]*b[i]%p*fac[i]%p;
    for(int i=now+1;i<len;i++)
        a[i]=0,c[i]=0;


    for(int i=0;i<=now;i++)
    {
        if(i&1) c[i]=-inv[i]+p;
        else c[i]=inv[i];
    }
    ntt(a,len,1);
    ntt(c,len,1);
    for(int i=0;i<len;i++)
        a[i]=a[i]*c[i]%p;
    ntt(a,len,-1);


    for(int i=0;i<=now;i++)
        printf("%lld ",a[i]);
    putchar('\n');
    return 0;

}


只是比上面那种写法多一次ntt而已。

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<string>
#include<queue>
#define ll long long
#define llu unsigned ll
#define int ll
using namespace std;
const int maxn=8e5+1000;
const int p=998244353;
const int g=3;
int fi[maxn];
int a[maxn],b[maxn],c[maxn];
int fac[maxn],inv[maxn];

int mypow(int a,int b)
{
    if(b<0) return mypow(mypow(a,p-2),-b);
    int ans=1;
    while(b)
    {
        if(b&1) ans=ans*a%p;
        a=a*a%p;
        b>>=1;
    }
    return ans%p;
}

void init(int n)
{
    fac[0]=1;
    for(int i=1;i<=n;i++)
        fac[i]=fac[i-1]*i%p;
    inv[n]=mypow(fac[n],p-2);
    for(int i=n-1;i>=0;i--)
        inv[i]=inv[i+1]*(i+1)%p;
}

void ntt(int *x,int len,int f)
{
    for(int i=0;i<len;i++)
        if(i<fi[i]) swap(x[i],x[fi[i]]);

    for(int i=1;i<len;i<<=1)
    {
        int r=i<<1;
        int wn=mypow(g,f*(p-1)/r);
        for(int j=0;j<len;j+=r)
        {
            int w=1;
            for(int k=0;k<i;k++)
            {
                int xx=x[j+k],yy=w*x[j+i+k]%p;
                x[j+k]=(xx+yy)%p;
                x[j+i+k]=((xx-yy)%p+p)%p;
                w=w*wn%p;
            }
        }
    }
    if(f==-1)
    {
        int invn=mypow(len,p-2);
        for(int i=0;i<len;i++)
            x[i]=x[i]*invn%p;
    }
}

void getntt(int *a,int len,int n,int f)
{
    if(f==1)
    {
        for(int i=0;i<=n;i++)
            c[i]=inv[i];
    }
    else
    {
        for(int i=0;i<=n;i++)
        {
            if(i&1) c[i]=-inv[i]+p;
            else c[i]=inv[i];
        }
    }

    for(int i=n+1;i<=len;i++)
        c[i]=0;

    ntt(a,len,1);
    ntt(c,len,1);
    for(int i=0;i<len;i++)
        a[i]=a[i]*c[i]%p;
    ntt(a,len,-1);
}

signed main(void)
{
    int n,m;
    scanf("%lld%lld",&n,&m);
    for(int i=0;i<=n;i++)
        scanf("%lld",&a[i]);
    for(int i=0;i<=m;i++)
        scanf("%lld",&b[i]);

    int now=n+m;
    n=now,m=now;
    int len=1,cnt=0;
    while(len<=n+m) len<<=1,cnt++;
    for(int i=0;i<len;i++)
        fi[i]=((fi[i>>1]>>1)|((i&1)<<(cnt-1)));

    init(now);
    getntt(a,len,now,1);
    getntt(b,len,now,1);
    for(int i=0;i<=now;i++)
        a[i]=a[i]*b[i]%p*fac[i]%p;
    for(int i=now+1;i<len;i++)
        a[i]=0;
    getntt(a,len,now,-1);
    for(int i=0;i<=now;i++)
        printf("%lld ",a[i]);
    putchar('\n');
    return 0;

}

加上快读,再加上预处理原根的阶乘,也没卡过。。。

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<string>
#include<queue>
#define ll long long
#define llu unsigned ll
#define int ll
using namespace std;

char buffer[100001],*S,*T;
inline char Get_Char()
{
    if (S==T)
    {
        T=(S=buffer)+fread(buffer,1,100001,stdin);
        if (S==T) return EOF;
    }
    return *S++;
}
inline int read()
{
    char c;int re=0;
    for(c=Get_Char();c<'0'||c>'9';c=Get_Char());
    while(c>='0'&&c<='9') re=re*10+(c-'0'),c=Get_Char();
    return re;
}


const int maxn=8e5+1000;
const int p=998244353;
const int g=3;
int fi[maxn];
int a[maxn],b[maxn],c[maxn];
int fac[maxn],inv[maxn],gg[2][maxn];

int mypow(int a,int b)
{
    if(b<0) return mypow(mypow(a,p-2),-b);
    int ans=1;
    while(b)
    {
        if(b&1) ans=ans*a%p;
        a=a*a%p;
        b>>=1;
    }
    return ans%p;
}

void init(int n,int m)
{
    fac[0]=1;
    for(int i=1;i<=n;i++)
        fac[i]=fac[i-1]*i%p;
    inv[n]=mypow(fac[n],p-2);
    for(int i=n-1;i>=0;i--)
        inv[i]=inv[i+1]*(i+1)%p;
    for(int i=1;i<(m<<1);i<<=1)
        gg[0][i]=mypow(g,-(p-1)/i),gg[1][i]=mypow(g,(p-1)/i);
}

void ntt(int *x,int len,int f)
{
    for(int i=0;i<len;i++)
        if(i<fi[i]) swap(x[i],x[fi[i]]);

    for(int i=1;i<len;i<<=1)
    {
        int r=i<<1;
        int wn=gg[f][r];
        for(int j=0;j<len;j+=r)
        {
            int w=1;
            for(int k=0;k<i;k++)
            {
                int xx=x[j+k],yy=w*x[j+i+k]%p;
                x[j+k]=(xx+yy)%p;
                x[j+i+k]=((xx-yy)%p+p)%p;
                w=w*wn%p;
            }
        }
    }
    if(f==0)
    {
        int invn=mypow(len,p-2);
        for(int i=0;i<len;i++)
            x[i]=x[i]*invn%p;
    }
}


signed main(void)
{
    int n,m;
    n=read(),m=read();
    for(int i=0;i<=n;i++)
        a[i]=read();
    for(int i=0;i<=m;i++)
        b[i]=read();

    int now=n+m;
    n=now,m=now;
    int len=1,cnt=0;
    while(len<=n+m) len<<=1,cnt++;
    for(int i=0;i<len;i++)
        fi[i]=((fi[i>>1]>>1)|((i&1)<<(cnt-1)));

    init(now,len);
    for(int i=0;i<=now;i++)
        c[i]=inv[i];
    ntt(a,len,1);
    ntt(b,len,1);
    ntt(c,len,1);
    for(int i=0;i<len;i++)
        a[i]=a[i]*c[i]%p,b[i]=b[i]*c[i]%p;
    ntt(a,len,0);
    ntt(b,len,0);
    for(int i=0;i<=now;i++)
        a[i]=a[i]*b[i]%p*fac[i]%p;
    for(int i=now+1;i<len;i++)
        a[i]=0,c[i]=0;


    for(int i=0;i<=now;i++)
    {
        if(i&1) c[i]=-inv[i]+p;
        else c[i]=inv[i];
    }
    ntt(a,len,1);
    ntt(c,len,1);
    for(int i=0;i<len;i++)
        a[i]=a[i]*c[i]%p;
    ntt(a,len,0);


    for(int i=0;i<=now;i++)
        printf("%lld ",a[i]);
    putchar('\n');
    return 0;

}



  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值