快速傅里叶变换在这里的主要用处就是可以快速求出两个多项式的乘积,可以把两个大数转换成a1 + a2*x + a3*x^2......的形式,利用FFT快速求值。
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
using namespace std;
const int maxn = 200000 + 7;
const double PI = acos(-1.0);
struct complex {
double a, b;
complex(double aa = 0, double bb = 0) : a(aa), b(bb) { }
complex operator+(const complex &e) {
return complex(a + e.a, b + e.b);
}
complex operator-(const complex &e) {
return complex(a - e.a, b - e.b);
}
complex operator*(const complex &e) {
return complex(a * e.a - b * e.b, a * e.b + b * e.a);
}
};
void change(complex y[], int len) {
for(int i = 1, j = len / 2; i < len - 1; ++i) {
if(i < j) swap(y[i], y[j]);
int k = len / 2;
while(j >= k) {
j -= k;
k /= 2;
}
if(j < k) j += k;
}
return ;
}
//FFT快速傅里叶变换的模板,用以将多项式系数转换成单位根,这样得到的两个序列逐个相乘到得就是系数
void FFT(complex y[], int len, int on) {
change(y, len);
for(int h = 2; h <= len; h <<= 1) {
complex wn(cos(-on * 2 * PI / h), sin(-on * 2 * PI / h));
for(int j = 0; j < len; j += h) {
complex w(1, 0);
for(int k = j; k < j + h / 2; ++k) {
complex u = y[k];
complex t = w * y[k+h/2];
y[k] = u + t;
y[k+h/2] = u - t;
w = w * wn;
}
}
}
if(on == -1) {
for(int i = 0; i < len; ++i)
y[i].a /= len;
}
}
int num[maxn];
complex x[maxn], y[maxn];
char s1[maxn/4], s2[maxn/4];
int main() {
while(scanf("%s%s", s1, s2) != EOF) {
int len1 = strlen(s1), len2 = strlen(s2);
int len = 1;
//这里求最接近且大于n+m-1的2^k,方便二叉树形式的运用吧
while(len < len1 + len2) len <<= 1;
for(int i = 0; i < len1; ++i)
x[i] = complex(s1[len1-i-1] - '0', 0);
for(int i = len1; i < len; ++i)
x[i] = complex(0, 0);
for(int i = 0; i < len2; ++i)
y[i] = complex(s2[len2-i-1] - '0', 0);
for(int i = len2; i < len; ++i)
y[i] = complex(0, 0);
FFT(x, len, 1);//FFT转换成单位根形式
FFT(y, len, 1);
for(int i = 0; i < len; ++i)
x[i] = x[i] * y[i];//卷积
FFT(x, len, -1);//结果转换回来
int temp = 0;
for(int i = 0; i < len; ++i) {
temp = temp + (int)(x[i].a + 0.5);
num[i] = temp % 10;
temp /= 10;
}
while(temp) {
num[len++] = temp % 10;
temp /= 10;
}
int i = len - 1;
while(num[i] == 0 && i > 0) i--;
for(; i >= 0; --i)
printf("%d", num[i]);
printf("\n");
}
return 0;
}