题目描述
传送门
题目大意:给两个次数界为n的多项式,求这两个多项式的乘积,输出前x的0次项到n-1次项的系数 mod 23333333
题解
NTT只能求在FFT模数下的值。对于任意模数的题来说,我们可以选择三个FFT模数分别做NTT,最后用中国剩余定理合并。
一次卷积后每个数可以达到
1023(n∗mod2)
左右,所以我们需要选择三个乘积大于
1023
的FFT模数。合并的时候
1023>264
,所以我们不能直接用中国剩余定理,需要一点小技巧。
ans≡a1 (mod m1)
ans≡a2 (mod m2)
ans≡a3 (mod m3)
先用中国剩余定理合并 a1,a2 ,得到 ans≡A (mod M)
其中 A=a1∗m2∗m2−1+a2∗m1∗m1−1,M=m1∗m2
ans=k∗M+A=x∗m3+a3
k∗M≡a3−A (mod m3)
k≡(a3−A)∗M−1 (mod m3)
求出k后带入 ans=k∗M+A ,然后最后ans对23333333取模即可。
中间做乘法的过程需要用到快速乘,否则会炸long long
代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#define N 500003
#define mod 23333333
#define LL long long
using namespace std;
LL M[10],f[4][N],g[4][N],ans[N];
int n,m,n1,L,R[N];
LL quickpow(LL num,LL x,LL p)
{
LL base=num%p; LL ans=1;
while (x) {
if (x&1) ans=ans*base%p;
x>>=1;
base=base*base%p;
}
return ans;
}
void NTT(LL a[N],int n,int opt,LL p)
{
for (int i=0;i<n;i++)
if (i>R[i]) swap(a[i],a[R[i]]);
for (int i=1;i<n;i<<=1) {
LL wn=quickpow(3,(p-1)/(i<<1),p);
for (int p1=i<<1,j=0;j<n;j+=p1){
LL w=1;
for (int k=0;k<i;k++,w=w*wn%p){
LL x=a[j+k],y=w*a[j+k+i]%p;
a[j+k]=(x+y)%p; a[j+k+i]=(x-y+p)%p;
}
}
}
if (opt==-1) reverse(a+1,a+n);
}
void solve(LL a[N],LL b[N],LL p)
{
// for (int i=0;i<=n1;i++) a[i]%=p;
// for (int i=0;i<=n1;i++) b[i]%=p;
NTT(a,n1,1,p); NTT(b,n1,1,p);
for (int i=0;i<n1;i++) a[i]=a[i]*b[i]%p;
NTT(a,n1,-1,p);
LL rev=quickpow(n1,p-2,p);
for (int i=0;i<n1;i++) a[i]=a[i]*rev%p;
}
void exgcd(LL a,LL b,LL &x,LL &y)
{
if (!b) {
x=1; y=0;
return;
}
exgcd(b,a%b,x,y);
LL t=y;
y=x-(a/b)*y;
x=t;
}
LL mul(LL num,LL x,LL p)
{
LL ans=0; LL base=num%p;
while (x) {
if (x&1) ans=(ans+base)%p;
x>>=1;
base=(base+base)%p;
}
return ans;
}
LL china(LL a1,LL a2,LL a3)
{
LL MM=M[1]*M[2];
LL x,x1; x=quickpow(M[2],M[1]-2,M[1]); x1=quickpow(M[1],M[2]-2,M[2]);
LL A=(mul(a1*M[2]%MM,x%MM,MM)+mul(a2*M[1]%MM,x1%MM,MM))%MM;
LL k=(a3-A)%M[3]*quickpow(MM,M[3]-2,M[3])%M[3];
k=(k%M[3]+M[3])%M[3];
return ((k%mod)*(MM%mod)%mod+A)%mod;
}
int main()
{
freopen("annona_squamosa.in","r",stdin);
freopen("annona_squamosa.out","w",stdout);
scanf("%d",&n);
M[1]=998244353,M[2]=1004535809,M[3]=469762049;
for (int i=0;i<n;i++) scanf("%lld",&f[1][i]),f[2][i]=f[3][i]=f[1][i];
for (int i=0;i<n;i++) scanf("%lld",&g[1][i]),g[2][i]=g[3][i]=g[1][i];
m=2*n;
for (n1=1;n1<=m;n1<<=1) L++;
for (int i=0;i<=n1;i++) R[i]=(R[i>>1]>>1)|((i&1)<<(L-1));
for (int i=1;i<=3;i++)
solve(f[i],g[i],M[i]);
for (int i=0;i<n;i++) printf("%lld\n",china(f[1][i],f[2][i],f[3][i]));
}