两个多项式相乘,若暴力计算,进行系数乘法,时间复杂度为
借助快速傅立叶变换FFT或NTT可以降至
原理大家可以看--> 点这里
多项式乘法本质上是对每一项进行精细的计算,所以我们也可以用它来求高精度乘法
以下是FFT模版
const int N=4e6+10;
struct Cmp{
double x,y;
Cmp(double xx=0,double yy=0){
x=xx;
y=yy;
}
};
Cmp operator + (Cmp& a,Cmp& b){
return Cmp(a.x+b.x,a.y+b.y);
}
Cmp operator - (Cmp& a,Cmp& b){
return Cmp(a.x-b.x,a.y-b.y);
}
Cmp operator * (Cmp& a,Cmp& b){
return Cmp(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);
}
int R[N];
void FFT(Cmp A[],int n,int op){
for(int i=0;i<n;i++){
if(i<R[i])
swap(A[i],A[R[i]]);
}
for(int m=2;m<=n;m<<=1){
Cmp w1({cos(2*PI/m),sin(2*PI/m)*op});
for(int i=0;i<n;i+=m){
Cmp wk({1,0});
for(int j=0;j<m/2;j++){
Cmp x=A[i+j],y=A[i+j+m/2]*wk;
A[i+j]=x+y;
A[i+j+m/2]=x-y;
wk=wk*w1;
}
}
}
}
Cmp A[N],B[N];
void solve(){
int n,m;
cin >> n >> m;
for(int i=0;i<=n;i++){
cin >> A[i].x;
}
for(int i=0;i<=m;i++){
cin >> B[i].x;
}
for(m=m+n,n=1;n<=m;n<<=1);
for(int i=0;i<n;i++){
R[i]=R[i/2]/2+((i&1)?n/2:0);
}
FFT(A,n,1);
FFT(B,n,1);//求点积
for(int i=0;i<n;i++){
A[i]=A[i]*B[i];
}
FFT(A,n,-1);//求系数
for(int i=0;i<=m;i++){
int x=(int)((A[i].x)/n+0.5);
cout << x << ' ';
//printf("%d ",(int)((A[i].x)/n+0.5));
}
}
以下是NTT模版
#define int ll
const int N=4e6+10;
const int P=998244353;
int R[N];
int qpow(int a,int b){
int res=1;
while(b){
if(b&1)
res=(res*a)%P;
b>>=1;
a=(a*a)%P;
}
return res%P;
}
int g=3,gi=332748118;
void NTT(int A[],int n,int op){
for(int i=0;i<n;i++){
if(i<R[i]){
swap(A[i],A[R[i]]);
}
}
for(int i=2;i<=n;i<<=1){
ll g1=qpow(op==1?g:gi,(P-1)/i);
for(int j=0;j<n;j+=i){
ll gk=1;
for(int k=j;k<j+i/2;k++){
ll x=A[k],y=gk*A[k+i/2]%P;
A[k]=(x+y)%P,A[k+i/2]=(x-y+P)%P;
gk=(gk*g1)%P;
}
}
}
}
int A[N],B[N];
void solve(){
int n,m;
cin >> n >> m;
for(int i=0;i<=n;i++){
cin >> A[i];
}
for(int i=0;i<=m;i++){
cin >> B[i];
}
for(m=m+n,n=1;n<=m;n<<=1);
for(int i=0;i<n;i++){
R[i]=R[i/2]/2+((i&1)?n/2:0);
}
NTT(A,n,1);
NTT(B,n,1);
for(int i=0;i<n;i++){
A[i]=(A[i]*B[i])%P;
}
NTT(A,n,-1);
int inv = qpow(n, P - 2);
for(int i=0;i<=m;i++){
cout << A[i]*inv%P << ' ';
}
}
以下是FFT高精度
NTT同样也可以
和普通的高精度乘法一样,需要反转存储原始数据
在运算结束后需要进位和去除前导0
int ans[N];
void solve(){
string s1,s2;
cin >> s1 >> s2;
int n=s1.size()-1,m=s2.size()-1;
for(int i=0;i<=n;i++){
A[i].x=s1[n-i]-'0';
}
for(int i=0;i<=m;i++){
B[i].x=s2[m-i]-'0';
}
for(m=m+n,n=1;n<=m;n<<=1);
FFT(A,n,1);
FFT(B,n,1);//求点积
for(int i=0;i<n;i++){
A[i]=A[i]*B[i];
}
FFT(A,n,-1);//求系数
int k=0;
for(int i=0,t=0;i<n||t;i++){
t+=A[i].x/n+0.5;
ans[k++]=t%10;
t/=10;
}
while(k>1&&ans[k-1]==0)
k--;
for(int i=k-1;i>=0;i--){
cout << ans[i];
}
}