题意:
两个大数相乘。
思路:
这题是典型的FFT。
所谓FFT,也就是快速傅里叶变换,是一种加速两个多项式系数相乘的算法,可以通过插值表示多项式的思想,将多项式从系数形式转化成点值形式,点值形式的多项式相乘只需要O(n),而FFT的关键就在于点值形式和系数形式之间的相互转化,可以利用分治的思想达到O(logn)的复杂度。
关于FFT的详细推导可以学习算法导论第30章。
一个题目如果需要用FFT来加速运算,首先要保证可以转化成多项式相乘的类型。
然后依据以下四个步骤:
1. 补全多项式到长度为2^n,不够的情况高位补0,复杂度O(n)
2. DFT过程,将系数表达式转化成点值表达式,复杂度O(nlogn)
3. 点值表达式相乘,得到结果的点值表达式形式,复杂度O(n)
4. 逆DFT过程,将点值表达式还原成系数表达式,复杂度O(nlogn)
就这道题来说,这里两个数可以分别看成两个多项式,每一位上的数都可以看作是多项式中一项的系数,这样大数乘法其实就是两个多项式的乘法,可以利用FFT来加速运算。
不要忘了得到的系数表达式的系数可能不满足十进制的要求,这时候需要进位,以及去掉前导0。
代码:
#include <bits/stdc++.h>
using namespace std;
#define pi acos (-1)
#define maxn 200010
struct plex { // 定义复数类
double x, y;
plex (double _x = 0.0, double _y = 0.0) : x (_x), y (_y) {}
plex operator + (const plex &a) const {
return plex (x + a.x, y + a.y);
}
plex operator - (const plex &a) const {
return plex (x - a.x, y - a.y);
}
plex operator * (const plex &a) const {
return plex (x * a.x - y * a.y, x * a.y + y * a.x);
}
};
void change (plex *y, int len) {
for (int i = 1, j = len / 2; i < len - 1; i++) {
if (i < j) swap(y[i], y[j]);
int k = len / 2;
while (j >= k) {
j -= k;
k /= 2;
}
if (j < k) j += k;
}
}
void fft(plex y[], int len, int on) { // FFT过程,on==1时,将系数表达转换成点值表达,on==-1时,将点值表达转换成系数表达
change(y, len);
for(int h = 2; h <= len; h <<= 1) {
plex wn(cos(-on * 2 * pi / h), sin(-on * 2 * pi / h));
for(int j = 0; j < len; j += h) {
plex w(1, 0);
for(int k = j; k < j + h / 2; k++) {
plex u = y[k];
plex t = w * y[k + h / 2];
y[k] = u + t;
y[k + h / 2] = u - t;
w = w * wn;
}
}
}
if(on == -1) {
for(int i = 0; i < len; i++)
y[i].x /= len;
}
}
char a[maxn], b[maxn];
plex x1[maxn], x2[maxn];
int ans[maxn];
int main () {
while (scanf ("%s%s", a, b) == 2) {
int len = 2, l1 = strlen(a), l2 = strlen(b);
while (len < l1 * 2 || len < l2 * 2) len *= 2; // 扩充多项式长度到2^n
for (int i = 0; i < l1; i++)
x1[i] = plex(a[l1 - i - 1] - '0', 0); // 补0
for (int i = l1; i < len; i++)
x1[i] = plex(0, 0);
for (int i = 0; i < l2; i++)
x2[i] = plex(b[l2 - i - 1] - '0', 0);
for (int i = l2; i < len; i++)
x2[i] = plex(0, 0);
fft(x1, len, 1); // DFT过程
fft(x2, len, 1);
for (int i = 0; i < len; i++) // 点值形式下相乘
x1[i] = x1[i] * x2[i];
fft(x1, len, -1); // 逆DFT过程
for (int i = 0; i < len; i++) ans[i] = (int)(x1[i].x + 0.5);
for (int i = 0; i < len; i++) { // 需要进位
if (ans[i] >= 10) {
ans[i + 1] += ans[i] / 10;
ans[i] %= 10;
}
}
len = l1 + l2 - 1;
while (ans[len] <= 0 && len > 0) len--;
for (int i = len; i >= 0; i--)
printf("%d", ans[i]);
puts("");
}
return 0;
}