证明什么的这个博主写的很详细了
这里只给出代码
多项式乘,除(取模),求逆元:
例题–luoguP4512
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#define maxn 1000005
using namespace std;
inline int rd(){
int x=0,f=1;char c=' ';
while(c<'0' || c>'9') f=c=='-'?-1:1,c=getchar();
while(c<='9' && c>='0') x=x*10+c-'0',c=getchar();
return x*f;
}
int n,m,rev[maxn],a[maxn],b[maxn],c[maxn],tmp[maxn];
const int mod=998244353;
inline int qpow(int x,int k){
int ret=1;
while(k){
if(k&1) ret=1LL*ret*x%mod;
x=1LL*x*x%mod; k>>=1;
} return ret%mod;
}
inline void NTT(int *F,int type,int limit){
for(int i=0;i<limit;i++) if(i<rev[i]) swap(F[i],F[rev[i]]);
for(int mid=1;mid<limit;mid<<=1){
int Wn=qpow(3,type==1?(mod-1)/(mid<<1):(mod-1-(mod-1)/(mid<<1)));
for(int r=mid<<1,j=0;j<limit;j+=r){
int w=1;
for(int k=0;k<mid;k++,w=1LL*w*Wn%mod){
int x=F[j+k],y=1LL*w*F[j+k+mid]%mod;
F[j+k]=(x+y)%mod;F[j+mid+k]=(x-y+mod)%mod;
}
}
}
if(type==-1) {
int inv=qpow(limit,mod-2);
for(int i=0;i<limit;i++) F[i]=1LL*F[i]*inv%mod;
}
}
inline void MUL(int *a,int *b,int ll){//乘
int lim=0,limit=1; while(limit<=2*ll) limit<<=1,++lim;
for(int i=0;i<limit;i++)
rev[i]=(rev[i>>1]>>1)|((i&1)<<(lim-1));
NTT(a,1,limit); NTT(b,1,limit);
for(int i=0;i<limit;i++) a[i]=1LL*a[i]*b[i]%mod;
NTT(a,-1,limit);
}
inline void INV(int *a,int *b,int limit){//逆元
if(limit==1) {
b[0]=qpow(a[0],mod-2);
return;
}
INV(a,b,limit>>1);
memcpy(tmp,a,sizeof(int)*limit); memset(tmp+limit,0,sizeof(int)*limit);
int lim=0; while(!(limit>>lim&1)) lim++;
for(int i=0;i<(limit<<1);i++)
rev[i]=(rev[i>>1]>>1)|((i&1)<<lim);
NTT(tmp,1,limit<<1); NTT(b,1,limit<<1);
for(int i=0;i<(limit<<1);i++)
tmp[i]=(2LL*b[i]%mod+mod-1LL*tmp[i]*b[i]%mod*b[i]%mod)%mod;
NTT(tmp,-1,limit<<1);
memcpy(b,tmp,sizeof(int)*limit); memset(b+limit,0,sizeof(int)*limit);
}
inline void DIV(int *a,int n,int *b,int m){//除+取模
static int tmp[maxn],A[maxn],B[maxn];
for(int i=0;i<m;i++) c[i]=b[m-i-1];
for(int i=0;i<n;i++) A[i]=a[n-i-1];
int limit=1,d=n-m+1; while(limit<2*d) limit<<=1;
for(int i=n;i<limit;i++) A[i]=0;
for(int i=d;i<limit;i++) c[i]=0;
for(int i=0;i<limit;i++) B[i]=0;
INV(c,B,limit);
for(int i=d;i<limit;i++) B[i]=0;
MUL(A,B,max(n,d));
for(int i=d;i<=n<<1;i++) A[i]=0;
for(int i=0;i<d;i++)
if(i>d-i-1) swap(A[i],A[d-i-1]);
for(int i=0;i<m;i++) c[i]=b[i];
for(int i=0;i<d;i++) b[i]=A[i];
MUL(c,A,max(d,m));
for(int i=0;i<n;i++) a[i]=(a[i]+mod-c[i])%mod;
}
int main(){
n=rd(); m=rd();
for(int i=0;i<=n;i++) a[i]=rd();
for(int i=0;i<=m;i++) b[i]=rd();
DIV(a,n+1,b,m+1);
for(int i=0;i<n-m+1;i++) printf("%d ",b[i]);puts("");
for(int i=0;i<m;i++) printf("%d ",a[i]);
return 0;
}
多项式开方:
inline void SQRT(int *a,int *b,int limit){
if(limit==1){
b[0]=1; return;
}
SQRT(a,b,limit>>1);
memset(invb,0,sizeof(int)*limit);
INV(b,invb,limit);
int lim=0; while(!(limit>>lim&1)) ++lim;
for(int i=0;i<(limit<<1);i++)
rev[i]=(rev[i>>1]>>1)|((i&1)<<lim);
memcpy(tmp,a,sizeof(int)*limit); memset(tmp+limit,0,sizeof(int)*limit);
NTT(invb,1,limit<<1); NTT(tmp,1,limit<<1);
for(int i=0;i<(limit<<1);i++) tmp[i]=1LL*tmp[i]*inv2%mod*invb[i]%mod;
NTT(tmp,-1,limit<<1);
for(int i=0;i<limit;i++) b[i]=(1LL*b[i]*inv2%mod+tmp[i])%mod;
}
例题–CF438E
这是个生成函数的题,有一个式子,推出来之后需要多项式开方和逆元,具体式子可以看别人的题解,我困了
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#define maxn 300005
#define LL long long
using namespace std;
inline int rd(){
int x=0,f=1;char c=' ';
while(c<'0' || c>'9') f=c=='-'?-1:1,c=getchar();
while(c<='9' && c>='0') x=x*10+c-'0',c=getchar();
return x*f;
}
const int mod=998244353;
int n,m,rev[maxn];
int a[maxn],b[maxn],c[maxn],invb[maxn],tmp[maxn],inv2;
inline int qpow(int x,int k){
int ret=1;
while(k){
if(k&1) ret=1LL*ret*x%mod;
x=1LL*x*x%mod; k>>=1;
} return ret%mod;
}
inline void NTT(int *F,int type,int limit){
for(int i=0;i<limit;i++)
if(i<rev[i]) swap(F[i],F[rev[i]]);
for(int mid=1;mid<limit;mid<<=1){
int Wn=qpow(3,type==1?(mod-1)/(mid<<1):(mod-1-(mod-1)/(mid<<1)));
for(int r=mid<<1,j=0;j<limit;j+=r){
int w=1;
for(int k=0;k<mid;k++,w=1LL*w*Wn%mod){
int x=F[j+k],y=1LL*w*F[j+mid+k]%mod;
F[j+k]=(x+y)%mod,F[j+k+mid]=(x-y+mod)%mod;
}
}
}
if(type==-1){
int inv=qpow(limit,mod-2);
for(int i=0;i<limit;i++) F[i]=(1LL*F[i]*inv)%mod;
}
}
inline void INV(int *a,int *b,int limit){
if(limit==1) {
b[0]=qpow(a[0],mod-2);
return;
}
INV(a,b,limit>>1);
memcpy(tmp,a,sizeof(int)*limit); memset(tmp+limit,0,sizeof(int)*limit);
int lim=0; while(!(limit>>lim&1)) ++lim;
for(int i=0;i<(limit<<1);i++)
rev[i]=(rev[i>>1]>>1)|((i&1)<<lim);
NTT(tmp,1,limit<<1); NTT(b,1,limit<<1);
for(int i=0;i<(limit<<1);i++)
tmp[i]=(2LL*b[i]+mod-1LL*tmp[i]*b[i]%mod*b[i]%mod)%mod;
NTT(tmp,-1,limit<<1);
memcpy(b,tmp,sizeof(int)*limit); memset(b+limit,0,sizeof(int)*limit);
}
inline void SQRT(int *a,int *b,int limit){
if(limit==1){
b[0]=1; return;
}
SQRT(a,b,limit>>1);
memset(invb,0,sizeof(int)*limit);
INV(b,invb,limit);
int lim=0; while(!(limit>>lim&1)) ++lim;
for(int i=0;i<(limit<<1);i++)
rev[i]=(rev[i>>1]>>1)|((i&1)<<lim);
memcpy(tmp,a,sizeof(int)*limit); memset(tmp+limit,0,sizeof(int)*limit);
NTT(invb,1,limit<<1); NTT(tmp,1,limit<<1);
for(int i=0;i<(limit<<1);i++) tmp[i]=1LL*tmp[i]*inv2%mod*invb[i]%mod;
NTT(tmp,-1,limit<<1);
for(int i=0;i<limit;i++) b[i]=(1LL*b[i]*inv2%mod+tmp[i])%mod;
}
int main(){
n=rd(); m=rd(); inv2=qpow(2,mod-2);
for(int i=1;i<=n;i++){
int x=rd(); a[x]=mod-4;
}
int limit=1; a[0]=1;
while(limit<=m) limit<<=1;
SQRT(a,b,limit);
b[0]=(b[0]+1)%mod;
INV(b,c,limit);
for(int i=1;i<=m;i++) printf("%lld\n",1LL*c[i]*2LL%mod);
return 0;
}
还有一个有点毒的题luoguP4239
多项式求逆加强版,就是模数是
1
e
9
+
7
1e9+7
1e9+7,我用了三模数法,但因为负数不可以这样算,要一边算一边
C
R
T
CRT
CRT合并,调了一晚上···
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#define maxn 400005
#define LL long long
using namespace std;
inline int rd(){
int x=0,f=1;char c=' ';
while(c<'0' || c>'9') f=c=='-'?-1:1,c=getchar();
while(c<='9' && c>='0') x=x*10+c-'0',c=getchar();
return x*f;
}
const int p=1e9+7,mod1=998244353,mod2=1004535809,mod3=469762049;
const LL M=1LL*mod1*mod2;
int n,rev[maxn];
LL sta[maxn],stb[maxn],B[3][maxn],temp[3][maxn];
LL INV1,INV2,INV3;
inline LL mul(LL x,LL y,LL MOD){//real fast mul
LL tmp=(x*y-(LL)((long double)x/MOD*y+1e-8)*MOD);
return tmp<0?tmp+MOD:tmp;
}
inline int qpow(int x,int k,int MOD){//fake fast pow
int ret=1;
while(k){
if(k&1) ret=1LL*ret*x%MOD;
x=1LL*x*x%MOD; k>>=1;
} return ret%MOD;
}
inline void NTT(LL *F,int type,int limit,int MOD){
for(int i=0;i<limit;i++)
if(i<rev[i]) swap(F[i],F[rev[i]]);
for(int mid=1;mid<limit;mid<<=1){
int Wn=qpow(3,type==1?(MOD-1)/(mid<<1):(MOD-1-(MOD-1)/(mid<<1)),MOD);
for(int r=mid<<1,j=0;j<limit;j+=r){
int w=1;
for(int k=0;k<mid;k++,w=1LL*w*Wn%MOD){
int x=F[j+k],y=1LL*w*F[j+mid+k]%MOD;
F[j+k]=(x+y)%MOD,F[j+mid+k]=(x-y+MOD)%MOD;
}
}
}
if(type==-1){
int inv=qpow(limit,MOD-2,MOD);
for(int i=0;i<limit;i++) F[i]=1LL*F[i]*inv%MOD;
}
}
inline LL CRT(LL a1,LL a2,LL a3){
LL x=mul(1LL*a1*mod2%M,INV1,M)%M;
x+=mul(1LL*a2*mod1%M,INV2,M)%M;
a2=x%M; LL y=(a3-a2%mod3+mod3)%mod3*INV3%mod3;
return (M%p*y%p+a2%p)%p;
}
inline void solve(int x,int ll,int MOD){
NTT(temp[x],1,ll,MOD); NTT(B[x],1,ll,MOD);
for(int i=0;i<ll;i++)
B[x][i]=temp[x][i]*B[x][i]%MOD;
NTT(B[x],-1,ll,MOD);
}
inline void INV(LL *a,LL *b,int limit){
if(limit==1){
b[0]=qpow(a[0],p-2,p);
return;
}
INV(a,b,limit>>1);
for(int j=0;j<3;j++){
memcpy(temp[j],a,sizeof(LL)*limit); memset(temp[j]+limit,0,sizeof(LL)*limit);
memcpy(B[j],b,sizeof(LL)*limit);
}
int lim=0; while(!(limit>>lim&1)) ++lim;
for(int i=0;i<(limit<<1);i++)
rev[i]=(rev[i>>1]>>1)|((i&1)<<lim);
solve(0,limit<<1,mod1); solve(1,limit<<1,mod2); solve(2,limit<<1,mod3);
for(int i=0;i<limit;i++)
B[0][i]=B[1][i]=B[2][i]=(p-CRT(B[0][i],B[1][i],B[2][i]))%p;
B[0][0]=B[1][0]=B[2][0]=(B[0][0]+2)%p;
for(int j=0;j<3;j++){
memcpy(temp[j],b,sizeof(LL)*limit); memset(temp[j]+limit,0,sizeof(LL)*limit);
memset(B[j]+limit,0,sizeof(LL)*limit);
}
solve(0,limit<<1,mod1); solve(1,limit<<1,mod2); solve(2,limit<<1,mod3);
for(int i=0;i<limit;i++) b[i]=CRT(B[0][i],B[1][i],B[2][i]);
}
int main(){
n=rd();
for(int i=0;i<n;i++) sta[i]=rd();
INV1=qpow(mod2,mod1-2,mod1),INV2=qpow(mod1,mod2-2,mod2),INV3=qpow(M%mod3,mod3-2,mod3);
int limit=1; while(limit<n) limit<<=1;
INV(sta,stb,limit);
for(int i=0;i<n;i++) printf("%lld ",stb[i]);
return 0;
}
多项式求对数
luoguP4725
要用到求导和积分,但求导和积分在多项式中其实都很简单
这里搬了luogu的某题解
下面放上代码
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#define maxn 400005
#define LL long long
using namespace std;
const int mod=998244353;
inline int rd(){
int x=0,f=1;char c=getchar();
while(c<'0' || c>'9') f=c=='-'?-1:1,c=getchar();
while(c<='9' && c>='0') x=x*10+c-'0',c=getchar();
return x*f;
}
int n,rev[maxn],a[maxn],b[maxn],tmp[maxn],c[maxn];
inline int qpow(int x,int k){
int ret=1;
while(k){
if(k&1) ret=1LL*ret*x%mod;
x=1LL*x*x%mod; k>>=1;
} return ret;
}
inline void NTT(int *F,int type,int limit){
for(int i=0;i<limit;i++) if(i<rev[i]) swap(F[i],F[rev[i]]);
for(int mid=1;mid<limit;mid<<=1){
int Wn=qpow(3,type==1?(mod-1)/(mid<<1):(mod-1-(mod-1)/(mid<<1)));
for(int r=mid<<1,j=0;j<limit;j+=r){
int w=1;
for(int k=0;k<mid;k++,w=1LL*w*Wn%mod){
int x=F[j+k],y=1LL*w*F[j+mid+k]%mod;
F[j+k]=(x+y)%mod,F[j+k+mid]=(mod+x-y)%mod;
}
}
}
if(type==-1){
int Inv=qpow(limit,mod-2);
for(int i=0;i<limit;i++) F[i]=1LL*F[i]*Inv%mod;
}
}
inline void derivation(int n){
for(int i=1;i<n;i++) b[i-1]=1LL*a[i]*i%mod; b[n-1]=0;
}
inline void integral(int n){
for(int i=1;i<n;i++) a[i]=1LL*b[i-1]*qpow(i,mod-2)%mod; a[0]=0;
}
void INV(int *a,int *b,int limit){
if(limit==1){
b[0]=qpow(a[0],mod-2);
return;
}
INV(a,b,limit>>1);
memcpy(tmp,a,sizeof(int)*limit); memset(tmp+limit,0,sizeof(int)*limit);
int lim=0; while(!(limit>>lim&1)) ++lim;
for(int i=0;i<limit<<1;i++)
rev[i]=(rev[i>>1]>>1)|((i&1)<<lim);
NTT(tmp,1,limit<<1); NTT(b,1,limit<<1);
for(int i=0;i<limit<<1;i++)
tmp[i]=(2LL*b[i]%mod+mod-1LL*tmp[i]*b[i]%mod*b[i]%mod)%mod;
NTT(tmp,-1,limit<<1);
memcpy(b,tmp,sizeof(int)*limit); memset(b+limit,0,sizeof(int)*limit);
}
inline void MUL(int *a,int *b,int limit){
int lim=0; while(!(limit>>lim&1)) ++lim;
for(int i=0;i<limit;i++)
rev[i]=(rev[i>>1]>>1)|((i&1)<<(lim-1));
NTT(a,1,limit); NTT(b,1,limit);
for(int i=0;i<limit;i++) a[i]=1LL*a[i]*b[i]%mod;
NTT(a,-1,limit);
}
int main(){
n=rd();
for(int i=0;i<n;i++) a[i]=rd();
int limit=1; while(limit<=n) limit<<=1;
derivation(n); INV(a,c,limit); MUL(b,c,limit<<1);//注意这要<<1
integral(n);
for(int i=0;i<n;i++) printf("%d ",a[i]);
return 0;
}