任意模数NTT(MTT)

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/zhouyuheng2003/article/details/85561887

前言

众所周知,NTT有几个经典的模数:469762049,998244353,1004535809469762049,998244353,1004535809
为什么这些模数被称为NTT模数呢?因为他们都是这样一个形式:
P=2aX+1P=2^a*X+1
为什么要有这样一个条件呢,因为只有这样,才能找到所需的原根
所以对于一般的一个模数P=2aX+1P=2^a*X+1,能适用的最大的多项式长度(包括结果)是2a2^a
有时候,
给出的多项式长度超过限制,我们就不能用裸的NTT了
一般有两种情况:

  • 模数是NTT模数,但是多项式长度略超出限制(比如模数是1004535809,输入多项式长度和>2097152)
  • 模数不是NTT模数,比如模数是1000000007
    这个时候任意模数NTT就非常有用了

正文

我们来分析任意模数NTT做法的思路

思路一(P不是很大的时候)

根据分析,我们发现,多项式长度为N、模数为P的时候,多项式乘法的结果每一项的值0xPN20\le x\le PN^2
由于NTT的复杂度是Θ(nlogn)\Theta(nlogn)的,所以nn的范围可以出到10510^5以上,而对于10910^9级别的质数,那么结果大约是102310^{23}级别的。如果不考虑值域,有个很好的思路是:先进行FFT,算完后取模
非常不幸的是,由于结果的值域过大,FFT的精度往往都不够(这也是为什么要使用NTT的原因,根据实测,使用long double的FFT,当值域1013\le 10^{13}的时候,FFT是精度较好的,值域更大的时候出错概率就会比较高了,Tip:FFT的精度并不只与值域相关,多项式长度同样会影响精度(似乎是在Pi/n这个地方损失了精度),博主对各个长度都进行了测试,取min值
写一些上界(粗略)
N=1000000N=1000000 X=6000X=6000
N=100000N=100000 X=40000X=40000
N=10000N=10000 X=300000X=300000
N=1000N=1000 X=1000000X=1000000
N=100N=100 X=6000000X=6000000
N=10N=10 X=20000000X=20000000
(精度值在1000000下为6000,对拍程序为NTT)
所以如果你发现质数不是很大,即PN21013PN^2\le 10^{13}的时候,你可以放心的FFT(本测试的多项式长度上限为10610^6
注意:实际对于一般的FFT,保守限制为101010^{10},因为long double 可能会出现莫名的错误(博主太菜了,写的代码就出现UKE,有关using namespace std和std::的,导致精度大大下降)

思路二(基于FFT的优化)(2.6倍的普通FFT,多项式长度受限较大)

贴出一道模板题,本思路以及可以解决模板题了
洛谷模板题:任意模数NTT
我们发现FFT并不是一无是处,所以我们考虑压缩值域
p=Pp=\left\lceil\sqrt P\right\rceil
那么任何一个数都能表示成X=axp+bx(X<P,ax<p,bx<p)X=a_xp+b_x(X<P,a_x<p,b_x<p)的形式
那么我们考虑结果的一个值,对其进行分析
V=(axp+bx)(ayp+by)=axayp2+(axby+bxay)p+bxby\begin{aligned} V&=\sum(a_xp+b_x)*(a_yp+b_y)\\ &=\sum a_xa_yp^2+(a_xb_y+b_xa_y)p+b_xb_y\\ \end{aligned}
我们对每一组系数都进行计算,容易发现NN10510^5级别的,值域p<40000p<40000,所以可以求出各项,然后再加起来
贴出AC代码

#include<cstdio>
#include<cctype>
#include<cmath> 
#include<algorithm>
using namespace std;
namespace fast_IO
{
    const int IN_LEN=10000000,OUT_LEN=10000000;
    char ibuf[IN_LEN],obuf[OUT_LEN],*ih=ibuf+IN_LEN,*oh=obuf,*lastin=ibuf+IN_LEN,*lastout=obuf+OUT_LEN-1;
    inline char getchar_(){return (ih==lastin)&&(lastin=(ih=ibuf)+fread(ibuf,1,IN_LEN,stdin),ih==lastin)?EOF:*ih++;}
    inline void putchar_(const char x){if(oh==lastout)fwrite(obuf,1,oh-obuf,stdout),oh=obuf;*oh++=x;}
    inline void flush(){fwrite(obuf,1,oh-obuf,stdout);}
}
using namespace fast_IO;
#define getchar() getchar_()
#define putchar(x) putchar_((x))
#define rg register
typedef long long LL;
typedef long double LD;
#define double LD
template <typename T> inline T max(const T a,const T b){return a>b?a:b;}
template <typename T> inline T min(const T a,const T b){return a<b?a:b;}
template <typename T> inline void mind(T&a,const T b){a=a<b?a:b;}
template <typename T> inline void maxd(T&a,const T b){a=a>b?a:b;}
template <typename T> inline T abs(const T a){return a>0?a:-a;}
template <typename T> inline T gcd(const T a,const T b){if(!b)return a;return gcd(b,a%b);}
template <typename T> inline T lcm(const T a,const T b){return a/gcd(a,b)*b;}
template <typename T> inline T square(const T x){return x*x;};
template <typename T> inline void read(T&x)
{
    char cu=getchar();x=0;bool fla=0;
    while(!isdigit(cu)){if(cu=='-')fla=1;cu=getchar();}
    while(isdigit(cu))x=x*10+cu-'0',cu=getchar();
    if(fla)x=-x;
}
template <typename T> inline void printe(const T x)
{
    if(x>=10)printe(x/10);
    putchar(x%10+'0');
}
template <typename T> inline void print(const T x)
{
    if(x<0)putchar('-'),printe(-x);
    else printe(x);
}
const int maxn=262145;const double PI=acos((LD)-1.0);
int n,m;
struct complex
{
    double x,y;
    inline complex operator +(const complex b)const{return (complex){x+b.x,y+b.y};}
    inline complex operator -(const complex b)const{return (complex){x-b.x,y-b.y};}
    inline complex operator *(const complex b)const{return (complex){x*b.x-y*b.y,x*b.y+y*b.x};}
}ax[maxn],ay[maxn],bx[maxn],by[maxn];
int lenth=1,Reverse[maxn];
inline void init(const int x)
{
    rg int tim=0;
    while(lenth<=x)lenth<<=1,tim++;
    for(rg int i=0;i<lenth;i++)Reverse[i]=(Reverse[i>>1]>>1)|((i&1)<<(tim-1));
}
inline void FFT(complex*A,const int fla)
{
    for(rg int i=0;i<lenth;i++)if(i<Reverse[i])swap(A[i],A[Reverse[i]]);
    for(rg int i=1;i<lenth;i<<=1)
    {
        const complex w=(complex){cos(PI/i),fla*sin(PI/i)};
        for(rg int j=0;j<lenth;j+=(i<<1))
        {
            complex K=(complex){1,0};
            for(rg int k=0;k<i;k++,K=K*w)
            {
                const complex x=A[j+k],y=A[j+k+i]*K;
                A[j+k]=x+y;
                A[j+k+i]=x-y;
            }
        }
    }
}
int P,p;
int main()
{
    read(n),read(m),read(P);
    p=31624;
    init(n+m);
    for(rg int i=0;i<=n;i++)
    { 
        int x;read(x);
        ax[i].x=x/p,bx[i].x=x%p;
    }
    for(rg int i=0;i<=m;i++)
    {
        int x;read(x);
        ay[i].x=x/p,by[i].x=x%p;
    }
    FFT(ax,1),FFT(bx,1),FFT(ay,1),FFT(by,1);
    for(rg int i=0;i<lenth;i++)
    {
        const complex A=ax[i],B=bx[i],C=ay[i],D=by[i];
        ax[i]=A*C,ay[i]=B*D;
        bx[i]=A*D,by[i]=B*C;
    }
    FFT(ax,-1),FFT(bx,-1),FFT(ay,-1),FFT(by,-1);
    for(rg int i=0;i<=n+m;i++)
    {
        const LL A=ax[i].x/lenth+0.5,B=ay[i].x/lenth+0.5,C=bx[i].x/lenth+0.5,D=by[i].x/lenth+0.5;
        print((A%P*p%P*p%P+B%P+(C%P+D%P)*p%P)%P),putchar(' ');
    }
    return flush(),0;
}

这里又出现了UKE!!!博主太菜啦
如果p的赋值写成P\left\lceil\sqrt P\right\rceil,就会会在洛谷上WA两个点
如果发现我哪里写挂了,请速联系我!
效率分析,一次一般的多项式乘法共调用3次FFT函数,这里调用了8次,所以这种任意模数NTT算法常数大概是2.6左右
update by 2019.1.7:可以通过一些技巧减小精度损失以支持更多位数或在当前位数下只使用double(tip by yx2003)
详细方法:将FFT函数中的K直接预处理即可,减少乘法中的精度损失,对多项式长度较长(100000及以上) 的情况有较大优化效果
为什么呢?大概这个精度是受限两个方面:一个是值域上限的限制(在多项式长度较小,值域较大时体现),一个是多项式长度的限制(在多项式长度较大时体现)。容易发现多项式长的时候精度受限在单位根上,这个优化就是针对单位根精度的优化
提升效果:
N=100000N=100000
X=40000100000X=40000\Rightarrow100000

N=1000000N=1000000
X=600030000X=6000\Rightarrow30000
对于模板题的速度也能有较大提升,大约用时是原来的12\frac12
代码

#include<cstdio>
#include<cctype>
#include<cmath> 
#include<algorithm>
using namespace std;
namespace fast_IO
{
    const int IN_LEN=10000000,OUT_LEN=10000000;
    char ibuf[IN_LEN],obuf[OUT_LEN],*ih=ibuf+IN_LEN,*oh=obuf,*lastin=ibuf+IN_LEN,*lastout=obuf+OUT_LEN-1;
    inline char getchar_(){return (ih==lastin)&&(lastin=(ih=ibuf)+fread(ibuf,1,IN_LEN,stdin),ih==lastin)?EOF:*ih++;}
    inline void putchar_(const char x){if(oh==lastout)fwrite(obuf,1,oh-obuf,stdout),oh=obuf;*oh++=x;}
    inline void flush(){fwrite(obuf,1,oh-obuf,stdout);}
}
using namespace fast_IO;
#define getchar() getchar_()
#define putchar(x) putchar_((x))
#define rg register
typedef long long LL;
template <typename T> inline T max(const T a,const T b){return a>b?a:b;}
template <typename T> inline T min(const T a,const T b){return a<b?a:b;}
template <typename T> inline void mind(T&a,const T b){a=a<b?a:b;}
template <typename T> inline void maxd(T&a,const T b){a=a>b?a:b;}
template <typename T> inline T abs(const T a){return a>0?a:-a;}
template <typename T> inline T gcd(const T a,const T b){if(!b)return a;return gcd(b,a%b);}
template <typename T> inline T lcm(const T a,const T b){return a/gcd(a,b)*b;}
template <typename T> inline T square(const T x){return x*x;};
template <typename T> inline void read(T&x)
{
    char cu=getchar();x=0;bool fla=0;
    while(!isdigit(cu)){if(cu=='-')fla=1;cu=getchar();}
    while(isdigit(cu))x=x*10+cu-'0',cu=getchar();
    if(fla)x=-x;
}
template <typename T> inline void printe(const T x)
{
    if(x>=10)printe(x/10);
    putchar(x%10+'0');
}
template <typename T> inline void print(const T x)
{
    if(x<0)putchar('-'),printe(-x);
    else printe(x);
}
const int maxn=262145;const double PI=acos((double)-1.0);
int n,m;
struct complex
{
    double x,y;
    inline complex operator +(const complex b)const{return (complex){x+b.x,y+b.y};}
    inline complex operator -(const complex b)const{return (complex){x-b.x,y-b.y};}
    inline complex operator *(const complex b)const{return (complex){x*b.x-y*b.y,x*b.y+y*b.x};}
}ax[maxn],ay[maxn],bx[maxn],by[maxn];
int lenth=1,Reverse[maxn];
complex w[maxn];
complex fw[maxn];
inline void init(const int x)
{
    rg int tim=0;
    while(lenth<=x)lenth<<=1,tim++;
    for(rg int i=0;i<lenth;i++)Reverse[i]=(Reverse[i>>1]>>1)|((i&1)<<(tim-1)),w[i]=(complex){cos(i*PI/lenth),sin(i*PI/lenth)},fw[i]=(complex){cos(i*PI/lenth),-sin(i*PI/lenth)};
}
complex W[maxn];
inline void FFT(complex*A,const int fla)
{
    for(rg int i=0;i<lenth;i++)if(i<Reverse[i])swap(A[i],A[Reverse[i]]);
    for(rg int i=1;i<lenth;i<<=1)
    {
    	if(fla==1)
    	{
    		for(rg int k=0;k<i;k++)W[k]=w[lenth/i*k];
    	}
    	else
    	{
    		for(rg int k=0;k<i;k++)W[k]=fw[lenth/i*k];
    	}
        for(rg int j=0;j<lenth;j+=(i<<1))
        {
            for(rg int k=0;k<i;k++)
            {
                const complex x=A[j+k],y=W[k]*A[j+k+i];
                A[j+k]=x+y;
                A[j+k+i]=x-y;
            }
        }
    }
}
int P,p;
int main()
{
    read(n),read(m),read(P);
    p=31624;
    init(n+m);
    for(rg int i=0;i<=n;i++)
    { 
        int x;read(x);
        ax[i].x=x/p,bx[i].x=x%p;
    }
    for(rg int i=0;i<=m;i++)
    {
        int x;read(x);
        ay[i].x=x/p,by[i].x=x%p;
    }
    FFT(ax,1),FFT(bx,1),FFT(ay,1),FFT(by,1);
    for(rg int i=0;i<lenth;i++)
    {
        const complex A=ax[i],B=bx[i],C=ay[i],D=by[i];
        ax[i]=A*C,ay[i]=B*D;
        bx[i]=A*D,by[i]=B*C;
    }
    FFT(ax,-1),FFT(bx,-1),FFT(ay,-1),FFT(by,-1);
    for(rg int i=0;i<=n+m;i++)
    {
        const LL A=ax[i].x/lenth+0.5,B=ay[i].x/lenth+0.5,C=bx[i].x/lenth+0.5,D=by[i].x/lenth+0.5;
        print((A%P*p%P*p%P+B%P+(C%P+D%P)*p%P)%P),putchar(' ');
    }
    return flush(),0;
}
思路三(基于NTT的优化)

经过前面的分析,我们得知:FFT的运算结果NP2\le NP^2,是102310^{23}级别的
我们现在换一个思路,我们选出一些NTT模数(质数)(乘积大于FFT结果的最大值),求出在这些模意义下的值分别数多少,最后通过中国剩余定理(CRT)算出在给定模数的模意义下的值(选的质数一般是:469762049,998244353,1004535809469762049,998244353,1004535809
但是我们发现所有质数的乘积爆long long了,所以不能直接CRT
设一个数的值为AnsAns,选取的三个质数分别为p1,p2,p3p_1,p_2,p_3
我们通过6次DFT,3次IDFT算出在模意义下的值
Ansa1(modp1),Ansa2(modp2),Ansa3(modp3)Ans\equiv a_1\pmod {p_1},Ans\equiv a_2\pmod {p_2},Ans\equiv a_3\pmod {p_3}
根据中国剩余定理我们可以算出Ans=a4(modp1p2)Ans=a_4\pmod{p_1p_2}
Ans=a5p1p2+a4Ans=a_5p_1p_2+a_4,我们已知a4a_4,如果能求出a5a_5就能求出Ans的值
我们发现因为Ansa3(modp3)Ans\equiv a_3\pmod {p_3}
所以a5p1p2a3a4(modp3)a_5p_1p_2\equiv a_3-a_4\pmod {p_3}
就能推出a5(a3a4)p11p21(modp3)a_5\equiv (a_3-a_4)p_1^{-1}p_2^{-1}\pmod {p_3}
然后直接计算就好了
代码也非常好写(这份代码很不注重常数,只注重好写)

#include<cstdio>
#include<cctype>
#include<cstring>
#include<cmath>
namespace fast_IO
{
	const int IN_LEN=10000000,OUT_LEN=10000000;
	char ibuf[IN_LEN],obuf[OUT_LEN],*ih=ibuf+IN_LEN,*oh=obuf,*lastin=ibuf+IN_LEN,*lastout=obuf+OUT_LEN-1;
	inline char getchar_(){return (ih==lastin)&&(lastin=(ih=ibuf)+fread(ibuf,1,IN_LEN,stdin),ih==lastin)?EOF:*ih++;}
	inline void putchar_(const char x){if(oh==lastout)fwrite(obuf,1,oh-obuf,stdout),oh=obuf;*oh++=x;}
	inline void flush(){fwrite(obuf,1,oh-obuf,stdout);}
}
using namespace fast_IO;
#define getchar() getchar_()
#define putchar(x) putchar_((x))
typedef long long LL;
#define rg register
template <typename T> inline T max(const T a,const T b){return a>b?a:b;}
template <typename T> inline T min(const T a,const T b){return a<b?a:b;}
template <typename T> inline T mind(T&a,const T b){a=a<b?a:b;}
template <typename T> inline T maxd(T&a,const T b){a=a>b?a:b;}
template <typename T> inline T abs(const T a){return a>0?a:-a;}
template <typename T> inline void swap(T&a,T&b){T c=a;a=b;b=c;}
template <typename T> inline void swap(T*a,T*b){T c=a;a=b;b=c;}
template <typename T> inline T gcd(const T a,const T b){if(!b)return a;return gcd(b,a%b);}
template <typename T> inline T square(const T x){return x*x;};
template <typename T> inline void read(T&x)
{
    char cu=getchar();x=0;bool fla=0;
    while(!isdigit(cu)){if(cu=='-')fla=1;cu=getchar();}
    while(isdigit(cu))x=x*10+cu-'0',cu=getchar();
    if(fla)x=-x;  
}
template <typename T> void printe(const T x)
{
    if(x>=10)printe(x/10);
    putchar(x%10+'0');
}
template <typename T> inline void print(const T x)
{
    if(x<0)putchar('-'),printe(-x);
    else printe(x);
}
const int maxn=262145;
int n,m;
struct Ntt
{
	LL mod,a[maxn],b[maxn];;
	inline LL pow(LL x,LL y)
	{
		rg LL res=1;
		for(;y;y>>=1,x=x*x%mod)if(y&1)res=res*x%mod;
		return res;
	}
	int lenth,Reverse[maxn];
	inline void init(const int x)
	{
		rg int tim=0;lenth=1;
		while(lenth<=x)lenth<<=1,tim++;
		for(rg int i=0;i<lenth;i++)Reverse[i]=(Reverse[i>>1]>>1)|((i&1)<<(tim-1));
	}
	inline void NTT(LL*A,const int fla)
	{
		for(rg int i=0;i<lenth;i++)if(i<Reverse[i])swap(A[i],A[Reverse[i]]);
		for(rg int i=1;i<lenth;i<<=1)
		{
			LL w=pow(3,(mod-1)/i/2);
			if(fla==-1)w=pow(w,mod-2);
			for(rg int j=0;j<lenth;j+=(i<<1))
			{
				LL K=1;
				for(rg int k=0;k<i;k++,K=K*w%mod)
				{
					const LL x=A[j+k],y=A[j+k+i]*K%mod;
					A[j+k]=(x+y)%mod;
					A[j+k+i]=(mod+x-y)%mod;
				}
			}
		}
		if(fla==-1)
		{
			const int inv=pow(lenth,mod-2);
			for(rg int i=0;i<lenth;i++)A[i]=A[i]*inv%mod;
		}	
	}
}Q[3];
LL EXgcd(const LL a,const LL b,LL &x,LL &y)  
{  
    if(!b)
    {
		x=1,y=0;
        return a;  
    }
    const LL res=EXgcd(b,a%b,y,x);
    y-=a/b*x;
    return res;
}
inline LL msc(LL a,LL b,LL mod)
{
    LL v=(a*b-(LL)((long double)a/mod*b+1e-8)*mod);
    return v<0?v+mod:v;
}
int N,a[3],p[3];
LL CRT()
{  
    LL P=1,sum=0;  
    for(rg int i=1;i<=N;i++)P*=p[i];
    for(rg int i=1;i<=N;i++)  
	{
    	const LL m=P/p[i];
    	LL x,y;
    	EXgcd(p[i],m,x,y);
    	sum=(sum+msc(msc(y,m,P),a[i],P))%P;
	}
	return sum;
}
int P;
int main()
{
	read(n),read(m),read(P);
	Q[0].mod=469762049,Q[0].init(n+m);
	Q[1].mod=998244353,Q[1].init(n+m);
	Q[2].mod=1004535809,Q[2].init(n+m);
	for(rg int i=0;i<=n;i++)read(Q[0].a[i]),Q[2].a[i]=Q[1].a[i]=Q[0].a[i];
	for(rg int i=0;i<=m;i++)read(Q[0].b[i]),Q[2].b[i]=Q[1].b[i]=Q[0].b[i];
	Q[0].NTT(Q[0].a,1),Q[0].NTT(Q[0].b,1);
	Q[1].NTT(Q[1].a,1),Q[1].NTT(Q[1].b,1);
	Q[2].NTT(Q[2].a,1),Q[2].NTT(Q[2].b,1);
	for(rg int i=0;i<Q[0].lenth;i++)
		Q[0].a[i]=(LL)Q[0].a[i]*Q[0].b[i]%Q[0].mod,
		Q[1].a[i]=(LL)Q[1].a[i]*Q[1].b[i]%Q[1].mod,
		Q[2].a[i]=(LL)Q[2].a[i]*Q[2].b[i]%Q[2].mod;
	Q[0].NTT(Q[0].a,-1);
	Q[1].NTT(Q[1].a,-1);
	Q[2].NTT(Q[2].a,-1);
	N=2,p[1]=Q[0].mod,p[2]=Q[1].mod;
	const int INV=Q[2].pow(Q[0].mod,Q[2].mod-2)*Q[2].pow(Q[1].mod,Q[2].mod-2)%Q[2].mod;
	for(rg int i=0;i<=n+m;i++)
	{
		a[1]=Q[0].a[i],a[2]=Q[1].a[i];
		const LL ans1=CRT();
		const LL ans2=((Q[2].a[i]-ans1)%Q[2].mod+Q[2].mod)%Q[2].mod*INV%Q[2].mod;
		print((ans2*Q[0].mod%P*Q[1].mod%P+ans1)%P),putchar(' ');
	}
	return flush(),0;
}
思路四(思路二的进阶版)

容易发现,我们可以把那个pp的次数为11的项直接合并
这样就可以从调用8次DFT优化到调用7次
这里我就不另贴代码了
(另外,优化到这里,这个算法的速度依然很慢,中国剩余定理常数过大)
另外还可以用多项式的奇技淫巧优化常数
资料参考:毛啸,IOI2016国家集训队论文《再探快速傅里叶变换》
贴出myy给的代码链接
由于7次的DFT/IDFT已经很快了,所以咕咕咕咕咕咕
以后有空闲时间再更吧
现在先贴个链接,是txc写的任意模数 NTT 和 DFT 的优化学习笔记

总结

大概是比较清真的算法,如果推出来就很好记

没有更多推荐了,返回首页