Luogu P4512 【模板】多项式除法 (Template: polynomial division) — solved with NTT + polynomial inverse + polynomial division
Code:
#include<bits/stdc++.h>
// Capacity of every coefficient / scratch array; the largest NTT length used
// (and hence every index into rev/X/Y) must stay below this.
#define maxn 300000
// Shorthand for the 64-bit integer type used for all coefficients.
#define ll long long
// Prime modulus for all arithmetic (the NTT code uses primitive root 3 with it).
#define MOD 998244353
// Redirect stdin/stdout to "<s>.in" / "<s>.out" for local testing
// (the only call site, in main, is commented out).
#define setIO(s) freopen(s".in","r",stdin) ,freopen(s".out","w",stdout)
using namespace std;
namespace poly{
// Number-theoretic-transform helpers over MOD = 998244353 = 119 * 2^23 + 1.
// All work arrays are fixed-size globals of length maxn, so every transform
// length passed in must be a power of two no larger than maxn.

// P = MOD - 1 (order of the multiplicative group) and G = 3, a primitive root
// of MOD. These used to be `#define P` / `#define G`; macros ignore namespace
// scope, so those one-letter names leaked into the whole rest of the file.
// Typed constants stay inside poly:: and keep exactly the same values.
const ll P=MOD-1;
const ll G=3;

int rev[maxn];      // bit-reversal permutation for the current transform length
ll X[maxn],Y[maxn]; // scratch buffers used by mul()

// Fill rev[1..lim) with the l-bit reversal of each index (lim == 1 << l).
// rev[0] stays 0 from static zero-initialisation and is never written.
void calrev(int lim,int l){
    for(int i=1;i<lim;++i) rev[i]=(rev[i>>1]>>1)|((i&1)<<(l-1));
}

// (a + b) % MOD.
ll add(ll a,ll b){ return ((a+=b)%=MOD); }

// a^k % MOD by binary exponentiation (k >= 0).
ll qpow(ll a,ll k){
    ll base=1;
    for(;k;a=(a*a)%MOD,k>>=1) if(k&1) base=(base*a)%MOD;
    return base;
}

// In-place iterative NTT of a[0..len); len must be a power of two and rev must
// already hold the bit-reversal table for len (see calrev).
// opt == 1: forward transform; opt == -1: inverse transform, including the
// final multiplication by 1/len.
void NTT(ll *a,int len,int opt){
    for(int i=0;i<len;++i) if(i<rev[i]) swap(a[i],a[rev[i]]);
    for(int i=1;i<len;i<<=1){          // i is half the butterfly span
        int step=i<<1;
        // G^(P/step) is a primitive step-th root of unity; for opt == -1 the
        // exponent becomes P - P/step, i.e. its inverse. Adding P keeps the
        // exponent non-negative.
        ll wn=qpow(G,(opt*P/step+P));
        for(int j=0;j<len;j+=step){
            ll w=1;
            for(int k=0;k<i;++k,w=w*wn%MOD){
                ll x=a[j+k];
                ll y=w*a[j+k+i]%MOD;
                a[j+k]=(x+y)%MOD;
                a[j+k+i]=(x-y+MOD)%MOD;
            }
        }
    }
    if(opt==-1){
        ll r=qpow(len,MOD-2); // 1/len via Fermat's little theorem
        for(int i=0;i<len;++i) a[i]=a[i]*r%MOD;
    }
}

// x <- x * y using transforms of length lim. Only the first lim/2 coefficients
// of x and y are read, so the product of those truncated inputs never wraps
// around; all lim result coefficients are written back into x.
// x and y may alias (both are copied before anything is written).
// rev must already hold the bit-reversal table for lim.
void mul(ll *x,ll *y,int lim){
    int half=lim>>1;
    for(int i=0;i<half;++i) X[i]=x[i],Y[i]=y[i];
    // Zero-pad just the region the transform reads. The previous full memset
    // of both maxn-sized buffers cost O(maxn) per call regardless of lim.
    for(int i=half;i<lim;++i) X[i]=Y[i]=0;
    NTT(X,lim,1),NTT(Y,lim,1);
    for(int i=0;i<lim;++i) X[i]=X[i]*Y[i]%MOD;
    NTT(X,lim,-1);
    for(int i=0;i<lim;++i) x[i]=X[i];
}

// Double buffer for the Newton iteration in get_inv (only B[0] and B[1] are
// used); C and D are currently unused.
ll B[3][maxn],C[maxn],D[maxn];

// Replace A[0..n] with the coefficients of A^{-1} mod x^{n+1}.
// A[0] must be invertible mod MOD (its inverse is taken with qpow).
// Newton iteration: in the round with precision `bas`, 2B - A*B^2 is computed
// mod x^bas from the previous round's B; bas doubles until it reaches the
// smallest power of two >= n+1.
void get_inv(int n,ll *A){
    int cur=0,bas=1,lim=2,len=1;
    B[cur][0]=qpow(A[0],MOD-2);
    calrev(lim,len);
    while(bas<=(n<<1)){
        cur^=1;
        // Next round's mul() reads exactly the prefix [0, lim) of this buffer,
        // and everything in [bas, lim) must be zero there, so clearing that
        // prefix is enough (it also covers any stale product from two rounds ago).
        memset(B[cur],0,sizeof(B[cur][0])*lim);
        for(int i=0;i<bas;++i) B[cur][i]=add(B[cur^1][i]<<1,0);          // 2B
        mul(B[cur^1],B[cur^1],lim),mul(B[cur^1],A,lim);                    // B^2 * A
        for(int i=0;i<bas;++i) B[cur][i]=add(B[cur][i],MOD-B[cur^1][i]);   // 2B - A*B^2
        bas<<=1,lim<<=1,++len;
        if(bas<=(n<<1)) calrev(lim,len);
    }
    for(int i=0;i<=n;++i) A[i]=B[cur][i];
}
} // namespace poly
// Input polynomials: A has degree n, B has degree m (coefficients lowest first).
// main later overwrites B with the product B*Q.
ll A[maxn];
ll B[maxn];
ll n;
ll m;
// Transform length and its base-2 logarithm for the multiplication in progress.
ll lim;
ll len;
// Coefficient-reversed copies for the reversal trick; Dr receives the quotient.
ll Ar[maxn];
ll Br[maxn];
ll Dr[maxn];
int main(){
    //setIO("input");
    // Input: n m, then the n+1 coefficients of A and the m+1 coefficients of B,
    // lowest degree first. n, m, A and B are long long, so every conversion must
    // be %lld -- the previous %d stored an int through a long long*, which is
    // undefined behaviour.
    scanf("%lld%lld",&n,&m);
    for(int i=0;i<=n;++i) scanf("%lld",&A[i]),Ar[i]=A[i];
    for(int i=0;i<=m;++i) scanf("%lld",&B[i]),Br[i]=B[i];
    // Reversal trick: with A_R(x) = x^n A(1/x) and B_R(x) = x^m B(1/x), the
    // quotient satisfies Q_R = A_R * B_R^{-1} mod x^{n-m+1}.
    reverse(Ar,Ar+n+1),reverse(Br,Br+m+1);
    // Discard coefficients above index n-m+1; higher terms cannot affect the
    // low-order coefficients of Q_R that are read below.
    for(int i=n-m+2;i<=max(n,m);++i) Br[i]=Ar[i]=0;
    poly::get_inv(n-m+1,Br); // Br <- Br^{-1} mod x^{n-m+2}
    lim=1,len=0;
    while(lim<=n-m+1+n-m+1) lim<<=1,++len;
    poly::calrev(lim,len), poly::mul(Ar,Br,lim); // Ar[0..n-m] now holds Q_R
    // Q[k] = Q_R[n-m-k]: print Q in increasing degree and keep it in Dr.
    for(int i=n-m;i>=0;--i) printf("%lld ",Ar[i]),Dr[n-m-i]=Ar[i];
    // Remainder R = A - B*Q; only its low m coefficients are printed.
    lim=1,len=0;
    while(lim<=n*2) lim<<=1,++len;
    poly::calrev(lim,len),poly::mul(B,Dr,lim); // B <- B*Q
    printf("\n");
    for(int i=0;i<=m-1;++i) {
        ll h=(A[i]-B[i]+MOD)%MOD;
        printf("%lld ",h);
    }
    return 0;
}