任意模数NTT 学习笔记

FFT可以用来快速计算卷积,有时候出题人会给出像998244353之类的良心模数,那么我们NTT就好了。
但是有些毒瘤偏不,他们不但给了模数,他们还给了不可以被拆成 x ∗ 2 k + 1 x*2^k+1 x2k+1形式的模数
这个时候就需要一些黑科技了。

拆系数FFT

嗯,不能边NTT边取模怎么办呢?
那我们就把它直接FFT完了再取模
高兴地写完代码一交,WA到自闭,额,好像溢出了,算一下最大值 n p 2 = 1 0 23 np^2=10^{23} np2=1023 显然已经爆long long了。

那么怎么把数弄的稍微小一点呢?这个时候就需要拆系数了
设一个阀值为 m = p m=\sqrt {p} m=p
可以把第一个多项式里每个系数都拆成 f i = a i ∗ m + b i f_i=a_i*m+b_i fi=aim+bi的形式
第二个多项式拆成 g i = c i ∗ m + d i g_i=c_i*m+d_i gi=cim+di的形式
因为比较懒不想写卷积形式,我们就算一下 f i ∗ g i f_i*g_i figi的值吧,两者显然是等价的
f i ∗ g i = ( a i ∗ m + b i ) ∗ ( c i ∗ m + d i ) = a i ∗ c i ∗ m 2 + ( a i ∗ d i + b i ∗ c i ) m + b i ∗ d i f_i*g_i=(a_i*m+b_i)*(c_i*m+d_i)=a_i*c_i*m^2+(a_i*d_i+b_i*c_i)m+b_i*d_i figi=(aim+bi)(cim+di)=aicim2+(aidi+bici)m+bidi
那么对ac、ad、bc、bd各跑一遍FFT,显然此时最大值变成了 n p = 1 0 14 np=10^{14} np=1014,然后按照上面公式合并取模肯定不会溢出
但是吧,上面跑了整整8遍FFT,这个常数有点大啊……
神仙毛爷爷在他的集训队论文里给出了神仙的优化方法,听说能够把FFT次数变到4次
学不动了……

三模数NTT

嗯,看到三模数就知道为什么了,还是要解决爆long long的问题,那我们就取3个1e9左右的模数,跑NTT,再用crt合并一下,还原出原数就可以了(学CRT强烈安利z(hou)z(hi)d(ao)的博客)
嗯,还原出原数,又爆long long了
这里可能需要点技巧
我们先合并前两组
a n s ≡ A ( m o d P ) ans\equiv A(mod P) ansA(modP)
a n s ≡ a 3 ( m o d p 3 ) ans\equiv a_3(mod p_3) ansa3(modp3)
可以设 a n s = t P + A = k p 3 + a 3 ans=tP+A=kp_3+a_3 ans=tP+A=kp3+a3
t P ≡ a 3 − A ( m o d p 3 ) tP\equiv a_3-A(mod p_3) tPa3A(modp3)
所以 t ≡ ( a 3 − A ) P − 1 ( m o d p 3 ) t\equiv (a_3-A)P^{-1}(modp_3) t(a3A)P1(modp3)
这样假设右边是 x x x
t = k p 3 + x t=kp_3+x t=kp3+x
代入 a n s = t P + A ans=tP+A ans=tP+A
a n s = ( k p 3 + x ) p 1 p 2 + A = k p 1 p 2 p 3 + x P + A ans=(kp_3+x)p_1p_2+A=kp_1p_2p_3+xP+A ans=(kp3+x)p1p2+A=kp1p2p3+xP+A
因为crt的范围在 [ 0 , p 1 p 2 p 3 ) [0,p_1p_2p_3) [0,p1p2p3) 所以k=0
a n s = x P + A ans=xP+A ans=xP+A 直接模p就可以啦
代入以后可以得到ans,合理使用龟速乘就可以不爆精度

比较一波可以感受到三模数NTT的常数(六次)应该会比拆系数FFT(四次)来得大
然而,我选择三模数NTT……
————upd
算错NTT次数了qwq
被直接卡爆
拆系数FFT不带黑科技的也来一份吧……

代码如下:

#include<bits/stdc++.h>
#define gg 3
#define N 300030
using namespace std;

long long ans[N],f[3][N],g[3][N],mod1[]={998244353,469762049,1004535809};
int r[N],n,m,p,lim;

inline long long mul(long long a,long long b,long long mod)
{
    long long res=a*b-(long long)((long double)a*b/mod+0.5)*mod;
    return res<0?res+mod:res;
}

long long kasumi(long long a,long long b,long long mod)
{
	long long ans=1;
	while(b)
	{
		if(b&1) ans=ans*a%mod;
		a=a*a%mod;
		b>>=1;
	}
	return ans;
}

void NTT(long long *a,int kd,int mod)
{
	for(int i=0;i<lim;i++)
	{
		if(i<r[i]) swap(a[i],a[r[i]]);
	}
	for(int mid=1;mid<lim;mid<<=1)
	{
		long long wn=kasumi(gg,(mod-1)/(mid<<1),mod);
		if(kd) wn=kasumi(wn,mod-2,mod);
		for(int i=0;i<lim;i+=mid<<1)
		{
			long long w=1;
			for(int j=0;j<mid;j++,w=wn*w%mod)
			{
				long long x=a[i+j];
				long long y=a[i+j+mid]*w%mod;
				a[i+j]=(x+y)%mod;
				a[i+j+mid]=(x-y+mod)%mod;
			}
		}
	}
	if(kd)
	{
		int inv=kasumi(lim,mod-2,mod);
		for(int i=0;i<lim;i++) a[i]=a[i]*inv%mod;
	}
}

int main()
{
	// freopen("ha.in","r",stdin);
	// freopen("ha.out","w",stdout);
	lim=1;
	scanf("%d%d%d",&n,&m,&p);
	for(int i=0;i<=n;i++) 
	{
		scanf("%lld",&f[0][i]);
		f[0][i]=f[1][i]=f[2][i]=f[0][i]%p;
	}
	for(int i=0;i<=m;i++)
	{
		scanf("%lld",&g[0][i]);
		g[0][i]=g[1][i]=g[2][i]=g[0][i]%p;
	}
	int cnt=0;
	while(lim<=(n+m)) lim<<=1,cnt++;
	for(int i=0;i<lim;i++)
	{
		r[i]=(r[i>>1]>>1)|((i&1)<<(cnt-1));
	}
	for(int i=0;i<=2;i++)
	{
		NTT(f[i],0,mod1[i]);NTT(g[i],0,mod1[i]);
		for(int j=0;j<lim;j++)
		{
			f[i][j]=f[i][j]*g[i][j]%mod1[i];
		}
		NTT(f[i],1,mod1[i]);
	}
	long long inv1=kasumi(mod1[0],mod1[1]-2,mod1[1]);
	long long inv2=kasumi(mod1[1],mod1[0]-2,mod1[0]);
	long long mul1=mod1[0]*mod1[1];
	for(int i=0;i<lim;i++)
	{
		ans[i]+=mul(f[0][i]*inv2%mul1,mod1[1],mul1);
		ans[i]+=mul(f[1][i]*inv1%mul1,mod1[0],mul1);
		ans[i]%=mul1;
	}
	long long inv3=kasumi(mul1%mod1[2],mod1[2]-2,mod1[2]);
	for(int i=0;i<lim;i++)
	{
		ans[i]=((f[2][i]-ans[i]%mod1[2]+mod1[2])%mod1[2]*inv3%mod1[2]*(mul1%p)%p+ans[i]%p)%p;
	}
	for(int i=0;i<=n+m;i++) printf("%lld ",ans[i]%p);
}
#include<cstdio>
#include<string>
#include<cmath>
#include<algorithm>
#define sz 32768
#define N 600030
using std::swap;

long long ans[N];
int r[N],n,m,p,lim,cnt;
long long ff[N],gg[N];

const long double pi=std::acos(-1);

struct comp
{
    long double r,i;
    comp(){}
    comp(long double a,long double b):r(a),i(b){}
}f[2][N],g[2][N],t1[N],t2[N],t3[N];

inline comp operator +(const comp a,const comp b) {return comp(a.r+b.r,a.i+b.i);}

inline comp operator -(const comp a,const comp b) {return comp(a.r-b.r,a.i-b.i);}

inline comp operator *(const comp a,const comp b) {return comp(a.r*b.r-a.i*b.i,a.r*b.i+b.r*a.i);}

void FFT(comp *a,int kd,int lim)
{
    for(int i=0;i<lim;i++) if(i<r[i]) swap(a[i],a[r[i]]);
    for(int mid=1;mid<lim;mid<<=1)
    {
        comp wn=comp(std::cos(pi/mid),kd*std::sin(pi/mid));
        for(int i=0;i<lim;i+=(mid<<1))
        {
            comp w=comp(1.0,0.0);
            for(int j=0;j<mid;j++,w=wn*w)
            {
                comp x=a[i+j];
                comp y=a[i+j+mid]*w;
                a[i+j]=x+y;
                a[i+j+mid]=x-y;
            }
        }
    }
    if(kd==-1)
    {
        for(int i=0;i<lim;i++)
        {
            a[i].r/=lim;
        }
    }
}


void mul1(long long *a,long long *b,int cnt)
{
    int lim=1<<cnt;
    for(int i=0;i<lim;i++)
    {
        f[0][i].r=a[i]/sz;
        f[1][i].r=a[i]%sz;
        g[0][i].r=b[i]/sz;
        g[1][i].r=b[i]%sz;
        ans[i]=0;
    }
    for(int i=0;i<lim;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(cnt-1));
    FFT(f[0],1,lim);FFT(f[1],1,lim);
    FFT(g[0],1,lim);FFT(g[1],1,lim);
    for(int i=0;i<lim;i++)
    {
        t1[i]=f[0][i]*g[0][i];
        t2[i]=f[0][i]*g[1][i]+g[0][i]*f[1][i];
        t3[i]=f[1][i]*g[1][i];
    }
    FFT(t1,-1,lim);FFT(t2,-1,lim);FFT(t3,-1,lim);
    for(int i=0;i<lim;i++)
    {
        ans[i]=(((long long)(t1[i].r+0.5))%p*sz%p*sz%p+(((long long)(t2[i].r+0.5))%p*sz%p)+(long long)(t3[i].r+0.5)%p)%p;
    }
}

int main()
{
    int lim=1,cnt=0;
    scanf("%d%d%d",&n,&m,&p);
    for(int i=0;i<=n;i++) scanf("%lld",&ff[i]);
    for(int i=0;i<=m;i++) scanf("%lld",&gg[i]);
    while(lim<(n+m)) lim<<=1,cnt++;
    mul1(ff,gg,cnt);
    for(int i=0;i<=n+m;i++) printf("%lld ",ans[i]);
}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值