这一周一直在研究FFT……与其说是一周在研究,不如说是研究了2天调试了某蛋疼错误5天。 >_< ||
不得不说FFT的思想真是太神了,基本介绍这里就不写了,盾哥(FFT csdn jasonzhu8)和秋哥(FFT neroysq)写的已经很清楚了。
于是简单总结一下。 >_< ~
FFT处理多项式乘法,利用点值乘法的方便性,在点值和系数表示之间转换,从而得以加速。
所选点取复数域下1的n次根,记作w(n,k) (k=0~n-1),满足如下性质:
计算式:w(n,k) = cos(2*π*k / n) + i * sin(2*π*k / n)(复数的旋转,显然
折半性:w(n,k) = w(2n, 2k) (显然
周期性:w(n,k) = w(n, k+n) (三角函数周期,显然
半反性:w(n, k) = -w(n, k+n/2) (三角函数半周期性质,显然
逆矩阵:将点值过程看做列向量A与矩阵Y相乘得到B,那么插值过程为B*Y'(逆矩阵),这个逆矩阵是:原矩阵w(n,k)变为w(n,-k)后乘上(1/n)。 (完全不会证明 = =||
在实现的时候,方便起见,把n补成2的整数幂。
有了上面这些性质,我们每次将多项式的项奇偶分组,有:
F(x)= F0(x*x)+ x * F1(x*x)
F(w(n,k))= F0(w(n/2,k))+ w(n,k)* F1(w(n/2, k))= F0(w(n,2k))+ w(n,k)* F1(w(n,2k))
对于前一半项可以直接用两个子问题的解计算合并出来,后一半项因为有半反性和周期性,可以利用前半的结果一起算出。
每次把原问题化成两个相同递归子问题,合成是O(n)的,故总复杂度:O(n*logn)。
Code:
#include<cstdio>
#include<cstring>
#include<cmath>
using namespace std;
const int SN=10010, mo=1000;
const double pi=acos(-1);
char s[SN]; double cf[10]={1,10,100,1000,10000};
struct cp { double x,y; } A[SN],B[SN],w[SN],t[SN],wt;
int n,i,len,ex,top; long long ans[SN];
cp operator + (cp a, cp b) { return (cp){a.x+b.x, a.y+b.y}; }
cp operator - (cp a, cp b) { return (cp){a.x-b.x, a.y-b.y}; }
cp operator * (cp a, cp b) { return (cp){a.x*b.x-a.y*b.y, a.x*b.y+a.y*b.x}; }
void get(cp ar[])
{
int i,q=0,t=-1;
scanf("%s", s); len=strlen(s);
for(i=len-1; i>=0; i--)
{
if(++t>2) t=0,q++;
ar[q].x += cf[t]*(double)(s[i]-'0');
}
if(q+1>ex) ex=q+1;
}
void fft(cp a[], int s, int c)
{
int num=n>>c, bs=1<<c, i, p;
if(num==1) return;
fft(a, s, c+1);
fft(a, s+bs, c+1);
for(i=0; i<(num>>1); i++)
{
p=s+i*(bs<<1);
wt=w[i<<c]*a[p+bs];
t[i]=a[p]+wt;
t[i+(num>>1)]=a[p]-wt;
}
for(i=0; i<num; i++) a[s+i*bs]=t[i];
}
int main()
{
freopen("fft.in" , "r", stdin);
freopen("fft.out", "w", stdout);
get(A); get(B);
for(n=1; n<(ex<<1); n<<=1);
for(i=0; i<n; i++) w[i].x=cos( (pi*2/n)*i ), w[i].y=sin( (pi*2/n)*i );
fft(A, 0, 0);
fft(B, 0, 0);
for(i=0; i<n; i++) A[i]=A[i]*B[i], w[i].y=-w[i].y;
fft(A, 0, 0);
for(i=0; i<n; i++) if(A[i].x>0.5) top=i;
for(i=0; i<=top; i++) ans[i]=(long long)( (A[i].x+0.1)/n );
for(i=0; i<=top; i++)
if(ans[i]>=mo) top+=(i==top), ans[i+1]+=ans[i]/mo, ans[i]%=mo;
printf("%d", ans[top]);
for(i=top-1; i>=0; i--) printf("%03d", ans[i]);
return 0;
}