题目大意:给两个60000位以内的整数,求乘积
快速傅立叶变换入门题。递归常数巨大,因此需要采取迭代写法。
考虑进行递归FFT的过程:
①:将偶数为划分到左边,奇数位划分到右边
②:递归左右两边(设为a,b),设当前长度为len,则c[i]=a[i]+wi * b[i],c[i+(len/2)]=a[i]-wi * b[i]
观察这个过程,如果把b当成在a右边len/2个长度,可以发现每次操作所需要的两个数,操作后又正好落回了这两个位置。这说明操作可以只用一个数组完成。将递归过程想象成一棵树,根节点长度为n,有两个长度为n/2的儿子,如此一直到底。。。如果能预处理出每个元素在最底层的位置,就可以自下向上递推。
通过前人的观察与证明,对于长度为2^n的序列,初始第i项在最底层的位置是i的后n位二进制反转。可以用pos[i]=(pos[i>>1]>>1);if(i&1) pos[i]|=maxlen>>1;来O(n)递推。
迭代过程:第一层循环自下到上枚举每一层的节点长度i,最底层不用管,从2一直枚举到maxlen,每次乘以2,同时计算出该层的单位根wi。设step=i/2,则第x个元素与第x+step个元素配对。第二层循环枚举每个节点j,从0到 < maxlen,每次增加i,同时初始化w=1。第三层循环处理当前节点的所有元素k,从j到 < j+step(也就是这个节点的前一半),每次把k和k+step配对,使a=c[k],b=w*c[k+step],则c[k]=a+b,c[k+step]=a-b,同时将w乘上单位根wi。
不要忘记IDFT后每一位要除以maxlen。
注意由于算的是高精度乘法,还要对每一位取整后进行十进制进位。
#include<cstdio>
#include<cmath>
#include<algorithm>
#define gm 1<<17
using namespace std;
struct complex
{
double a,b;
complex(const double &a=0.0,const double &b=0.0):a(a),b(b){}
complex operator + (const complex &y) const
{
return complex(a+y.a,b+y.b);
}
complex operator - (const complex &y) const
{
return complex(a-y.a,b-y.b);
}
complex operator * (const complex &y) const
{
return complex(a*y.a-b*y.b,a*y.b+b*y.a);
}
}a[gm],b[gm],c[gm];
const double pi=acos(-1);
const double DFT=2.0,IDFT=-2.0;
double transform_mode;
size_t pos[gm];
inline void fast_transform(complex x[],size_t len)
{
for(size_t i=0;i<len;++i)
if(i<pos[i]) swap(x[i],x[pos[i]]);
for(size_t i=2;i<=len;i<<=1)
{
size_t step=i>>1;
complex wm(cos(2*pi/i),sin(transform_mode*pi/i));
for(size_t j=0;j<len;j+=i)
{
size_t lim=j+step;
complex w(1,0);
for(size_t k=j;k<lim;++k)
{
complex a=x[k],b=w*x[k+step];
x[k]=a+b,x[k+step]=a-b;
w=w*wm;
}
}
}
if(transform_mode==IDFT)
for(size_t i=0;i<len;++i) x[i].a/=len;
}
size_t n;
char __c;
size_t ans[gm],top=0;
int main()
{
scanf("%u",&n);
#define c __c
do c=getchar();while(c<'0'||c>'9');
for(size_t i=0;i<n;++i)
a[i].a=floor(c-'0'),c=getchar();
do c=getchar();while(c<'0'||c>'9');
for(size_t i=0;i<n;++i)
b[i].a=floor(c-'0'),c=getchar();
#undef c
size_t len=1;
while(len<(n<<1)) len<<=1;
for(size_t i=0;i<len;++i)
{
pos[i]=pos[i>>1]>>1;
if(i&1) pos[i]|=len>>1;
}
transform_mode=DFT;
fast_transform(a,len);
fast_transform(b,len);
for(size_t i=0;i<len;++i) c[i]=a[i]*b[i];
transform_mode=IDFT;
fast_transform(c,len);
for(size_t i=(n<<1)-2;~i;--i) ans[top++]=floor(c[i].a+0.5);
for(size_t i=0;i<top;++i) ans[i+1]+=ans[i]/10,ans[i]%=10;
if(ans[top]) printf("%u",ans[top]);
for(size_t i=top-1;~i;--i) printf("%u",ans[i]);
putchar('\n');
return 0;
}