引入
对快速傅里叶变换(FFT)的缺点进行了优化。
在计算多项式乘法(卷积)时,FFT设计三角函数、复数等很多恶心的东西,有着最大的缺点:精度问题,而在很多题目中往往需要进行取模,要求精度很高,FFT就不行了。
于是就有了快速数论变换
原根
FFT之所以可以实现,是利用了单位复根
ω
\omega
ω的周期性质,
ω
n
n
=
1
,
ω
n
k
=
ω
n
k
+
n
\omega_n^n=1,\omega_n^k=\omega_n^{k+n}
ωnn=1,ωnk=ωnk+n;
通过这个性质,可以把FFT后续所有步骤全部推导出来。
NTT由于需要取模,根据模数,我们可以重新定义一个类似于单位复根的东西,使它的幂有周期性,那就是原根
原根:
ω
n
n
≡
1
(
m
o
d
p
)
\omega_n^n\equiv 1(mod\space p)
ωnn≡1(mod p),且没有
(
k
=
1
,
2
,
3...
,
n
−
1
)
(k=1,2,3...,n-1)
(k=1,2,3...,n−1)
ω
n
k
≡
1
(
m
o
d
p
)
\omega_n^k\equiv 1(mod\space p)
ωnk≡1(mod p)。
对于每个质数
p
p
p,令
g
p
−
1
≡
1
(
m
o
d
p
)
g^{p-1}\equiv 1(mod\space p)
gp−1≡1(mod p),且
g
k
m
o
d
p
g^k\space mod\space p
gk mod p都不为1,
(
1
≤
k
≤
p
−
2
)
(1≤k≤p-2)
(1≤k≤p−2)
则
g
p
−
1
n
g^{\frac {p-1} n}
gnp−1就可以作为原根
ω
n
\omega_n
ωn,满足FFT中单位复根的一切性质。
将FFT中所有单位复根换位原根,就可以实现NTT了。
代码
//UOJ34
#include<cstdio>
#include<algorithm>
using namespace std;
const int MAXN=400005,MOD=998244353,G=3;
int PowMod(int a,int b)
{
int res=1;
for(;b;b>>=1,a=1LL*a*a%MOD)
if(b&1)
res=1LL*res*a%MOD;
return res;
}
void NTT(int A[],int n,int mode)
{
for(int i=0,j=0;i<n;i++)
{
if(i<j)swap(A[i],A[j]);
int k=n>>1;
while(k&j)
j^=k,k>>=1;
j^=k;
}
for(int i=1;i<n;i<<=1)
{
int w1=PowMod(ROOT,(MOD-1)/(i<<1));
if(mode==-1)
w1=PowMod(w1,MOD-2);
for(int j=0;j<n;j+=(i<<1))
for(int l=j,r=j+i,w=1;l<j+i;l++,r++,w=1LL*w*w1%MOD)
{
int tmp=1LL*A[r]*w%MOD;
A[r]=(A[l]-tmp+MOD)%MOD;
A[l]=(A[l]+tmp)%MOD;
}
}
if(mode==-1)
{
int invn=PowMod(n,MOD-2);
for(int i=0;i<n;i++)
A[i]=1LL*A[i]*invn%MOD;
}
}
void Multiply(const int A[],int len1,const int B[],int len2,int C[])
{
static int A0[MAXN*3],B0[MAXN*3];
int len=1;
for(;len<len1+len2-1;len<<=1);
for(int i=0;i<len1;i++)A0[i]=A[i];
for(int i=len1;i<len;i++)A0[i]=0;
for(int i=0;i<len2;i++)B0[i]=B[i];
for(int i=len2;i<len;i++)B0[i]=0;
NTT(A0,len,1);NTT(B0,len,1);
for(int i=0;i<len;i++)
A0[i]=1LL*A0[i]*B0[i]%MOD;
NTT(A0,len,-1);
for(int i=0;i<len1+len2-1;i++)C[i]=A0[i];
}
int A[MAXN],B[MAXN];
int main()
{
int n,m;
scanf("%d%d",&n,&m);
for(int i=0;i<=n;i++)
scanf("%d",A+i);
for(int i=0;i<=m;i++)
scanf("%d",B+i);
Multiply(A,n+1,B,m+1,A);
for(int i=0;i<n+m;i++)
printf("%d ",A[i]);
printf("%d\n",A[n+m]);
return 0;
}