之前看过很多FFT和NTT的算法思路,一直没有能实现这两种算法。因为最近的研究需要完成NTT算法进行多项式加速操作。
找的很多模版都有很多错误,或者可读性太差,没法修改,最近看到算法模版很不错。
FFT模版:
#include <complex>
#include <cmath>
#include <vector>
#include<iostream>
using namespace std;
const long double PI = acos(0.0) * 2.0;
typedef complex<double> CD;
// Cooley-Tukey的FFT算法,迭代实现。inverse = false时计算逆FFT
inline void FFT(vector<CD> &a, bool inverse) {
int n = a.size();
// 原地快速bit reversal
for(int i = 0, j = 0; i < n; i++) {
if(j > i) swap(a[i], a[j]);
int k = n;
while(j & (k >>= 1)) j &= ~k;
j |= k;
}
double pi = inverse ? -PI : PI;
for(int step = 1; step < n; step <<= 1) {
// 把每相邻两个“step点DFT”通过一系列蝴蝶操作合并为一个“2*step点DFT”
double alpha = pi / step;
// 为求高效,我们并不是依次执行各个完整的DFT合并,而是枚举下标k
// 对于一个下标k,执行所有DFT合并中该下标对应的蝴蝶操作,即通过E[k]和O[k]计算X[k]
// 蝴蝶操作参考:http://en.wikipedia.org/wiki/Butterfly_diagram
for(int k = 0; k < step; k++) {
// 计算omega^k. 这个方法效率低,但如果用每次乘omega的方法递推会有精度问题。
// 有更快更精确的递推方法,为了清晰起见这里略去
CD omegak = exp(CD(0, alpha*k));
for(int Ek = k; Ek < n; Ek += step << 1) { // Ek是某次DFT合并中E[k]在原始序列中的下标
int Ok = Ek + step; // Ok是该DFT合并中O[k]在原始序列中的下标
CD t = omegak * a[Ok]; // 蝴蝶操作:x1 * omega^k
a[Ok] = a[Ek] - t; // 蝴蝶操作:y1 = x0 - t
a[Ek] += t; // 蝴蝶操作:y0 = x0 + t
}
}
}
if(inverse)
for(int i = 0; i < n; i++) a[i] /= n;
}
// 用FFT实现的快速多项式乘法
inline vector<double> operator * (const vector<double>& v1, const vector<double>& v2) {
int s1 = v1.size(), s2 = v2.size(), S = 2;
while(S < s1 + s2) S <<= 1;
vector<CD> a(S,0), b(S,0); // 把FFT的输入长度补成2的幂,不小于v1和v2的长度之和
for(int i = 0; i < s1; i++) a[i] = v1[i];
FFT(a, false);
for(int i = 0; i < s2; i++) b[i] = v2[i];
FFT(b, false);
for(int i = 0; i < S; i++) a[i] *= b[i];
FFT(a, true);
vector<double> res(s1 + s2 - 1);
for(int i = 0; i < s1 + s2 - 1; i++) res[i] = a[i].real(); // 虚部均为0
return res;
}
/// 题目相关
#include<cstdio>
#include<cstring>
vector<double>a, b, ans;
int main()
{
int n, m;
scanf("%d%d", &n, &m);
for(int i = 1;i <= n+1;i++)
{
double tmp;
scanf("%lf", &tmp);
a.push_back(tmp);
}
for(int i = 1;i <= m+1;i++)
{
double tmp;
scanf("%lf", &tmp);
b.push_back(tmp);
}
ans = a * b;
for(int i = 0;i <= n+m;i++)
printf("%d ", (int)(ans[i] + 0.5));
return 0;
}
NTT模版:
#include<bits/stdc++.h>
#define rg register
using namespace std;
typedef long long ll;
const int mod=998244353,g=3;
const int maxn = 1e6 + 10;
inline int qpow(int x,int k)
{
int ans=1;
while(k)
{
if(k&1)
ans=(ll)ans*x%mod;
x=(ll)x*x%mod,k>>=1;
}
return ans;
}
inline int module(int x,int y)
{
x+=y;
if(x>=mod)
x-=mod;
return x;
}
int rev[4*maxn];
inline void NTT(int*t,int lim,int type)
{
for(rg int i=0;i<lim;++i)
if(i<rev[i])
swap(t[i],t[rev[i]]);
for(rg int i=1;i<lim;i<<=1)
{
int gn=qpow(g,(mod-1)/(i<<1));
if(type==-1)
gn=qpow(gn,mod-2);
for(rg int j=0;j<lim;j+=(i<<1))
{
int gi=1;
for(rg int k=0;k<i;++k,gi=(ll)gi*gn%mod)
{
int x=t[j+k],y=(ll)gi*t[j+i+k]%mod;
t[j+k]=module(x,y);
t[j+i+k]=module(x,mod-y);
}
}
}
if(type==-1)
{
int inv=qpow(lim,mod-2);
for(rg int i=0;i<lim;++i)
t[i]=(ll)t[i]*inv%mod;
}
}
int X[4*maxn],Y[4*maxn];
inline void mul(int*x, int*y, int n, int m)
{
memset(X,0,sizeof(X));
memset(Y,0,sizeof(Y));
int lim = 1, L = 0; //L=0必须写,局部变量默认值很可能不是0
while(lim <= n + m) lim <<= 1, L++; //lim为大于(n+m)的2的幂,所以最多需要4倍空间
for(int i = 0; i < lim; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (L - 1));
for(rg int i=0;i<lim;++i) X[i]=x[i],Y[i]=y[i];
NTT(X,lim,1);
NTT(Y,lim,1);
for(rg int i=0;i<lim;++i) X[i]=(ll)X[i]*Y[i]%mod;
NTT(X,lim,-1);
for(rg int i=0;i<lim;++i) x[i]=X[i];
}
int n, m;
int a[4*maxn], b[4*maxn];
int main()
{
scanf("%d%d", &n, &m);
for(int i = 0;i <= n;i++) scanf("%d", &a[i]);
for(int i = 0;i <= m;i++) scanf("%d", &b[i]);
mul(a, b, n, m);
for(int i = 0;i <= n+m;i++) printf("%d ", a[i]);
return 0;
}