8.13.3 ACM-ICPC 数学 多项式与生成函数 快速傅里叶变换
前置知识:复数
快速傅里叶变换(FFT)是一种在 𝑂(𝑛log𝑛)O(nlogn) 时间内计算两个 𝑛n 度多项式乘法的高效算法,比朴素的 𝑂(𝑛2)O(n2) 算法更高效。由于两个整数的乘法也可以被看作多项式乘法,这个算法也可用于加速大整数的乘法计算。
引入
考虑两个多项式 𝐴(𝑥)A(x) 和 𝐵(𝑥)B(x):
两个多项式相乘的积 𝐶(𝑥)=𝐴(𝑥)×𝐵(𝑥)C(x)=A(x)×B(x) 可以在 𝑂(𝑛2)O(n2) 时间复杂度内计算得出(这里 𝑛n 为 𝐴A 或 𝐵B 多项式的度):
很明显,多项式 𝐶C 的系数 𝑐𝑖ci 满足 𝑐𝑖=∑𝑗=0𝑖𝑎𝑗𝑏𝑖−𝑗ci=∑j=0iajbi−j。对于朴素算法而言,计算每一项的时间复杂度都是 𝑂(𝑛)O(n),一共有 𝑂(𝑛)O(n) 项,那么总的时间复杂度为 𝑂(𝑛2)O(n2)。
如果使用快速傅里叶变换(FFT),可以将复杂度降低到 𝑂(𝑛log𝑛)O(nlogn)。
傅里叶变换
傅里叶变换(Fourier Transform)是一种分析信号的方法,可以用来分析信号的成分,也可以用这些成分合成信号。傅里叶变换用正弦波作为信号的成分。
设 𝑓(𝑡)f(t) 是关于时间 𝑡t 的函数,则傅里叶变换可以检测频率 𝜔ω 的周期在 𝑓(𝑡)f(t) 中出现的程度:
它的逆变换是: 𝑓(𝑡)=𝐹−1[𝐹(𝜔)]=12𝜋∫−∞∞𝐹(𝜔)𝑒i𝜔𝑡𝑑𝜔f(t)=F−1[F(ω)]=2π1∫−∞∞F(ω)eiωtdω
傅里叶变换可以将时域的卷积转化为频域的乘积。
离散傅里叶变换(DFT)
离散傅里叶变换(DFT)将信号的时域采样变换为其离散时间傅里叶变换(DTFT)的频域采样。设 {𝑥𝑛}𝑛=0𝑁−1{xn}n=0N−1 是某一满足有限性条件的序列,它的离散傅里叶变换(DFT)为:
逆离散傅里叶变换(IDFT)为:
离散傅里叶变换可以将多项式在单位根处进行求值。
快速傅里叶变换(FFT)
FFT 是高效实现 DFT 的算法。它对傅里叶变换的理论并没有新的发现,但是对于在计算机系统或者数字系统中应用离散傅里叶变换有很大进步。快速数论变换(NTT)是快速傅里叶变换(FFT)在数论基础上的实现。
在 1965 年,Cooley 和 Tukey 发表了快速傅里叶变换算法。事实上,FFT 早在这之前就被发现过了,但是在当时现代计算机并未问世,人们没有意识到 FFT 的重要性。一些调查者认为 FFT 是由 Runge 和 König 在 1924 年发现的。但事实上高斯早在 1805 年就发明了这个算法,但一直没有发表。
分治法实现
FFT 算法的基本思想是分治。将多项式分为奇次项和偶次项处理。例如,对于 8 项多项式:
利用单位根的性质,进行递归 DFT。
代码实现
递归版 FFT:
#include <cmath>
#include <complex>
typedef std::complex<double> Comp;
const Comp I(0, 1); // i
const int MAX_N = 1 << 20;
Comp tmp[MAX_N];
void DFT(Comp* f, int n, int rev) {
if (n == 1) return;
for (int i = 0; i < n; ++i) tmp[i] = f[i];
for (int i = 0; i < n; ++i) {
if (i & 1)
f[n / 2 + i / 2] = tmp[i];
else
f[i / 2] = tmp[i];
}
Comp *g = f, *h = f + n / 2;
DFT(g, n / 2, rev), DFT(h, n / 2, rev);
Comp cur(1, 0), step(cos(2 * M_PI / n), sin(2 * M_PI * rev / n));
for (int k = 0; k < n / 2; ++k) {
tmp[k] = g[k] + cur * h[k];
tmp[k + n / 2] = g[k] - cur * h[k];
cur *= step;
}
for (int i = 0; i < n; ++i) f[i] = tmp[i];
}
非递归版 FFT:
#include <cmath>
#include <complex>
#include <algorithm>
typedef std::complex<double> Comp;
const double PI = acos(-1.0);
const int MAX_N = 1 << 20;
void change(Comp y[], int len) {
int i, j, k;
for (int i = 1, j = len / 2; i < len - 1; i++) {
if (i < j) std::swap(y[i], y[j]);
k = len / 2;
while (j >= k) {
j -= k;
k /= 2;
}
if (j < k) j += k;
}
}
void fft(Comp y[], int len, int on) {
change(y, len);
for (int h = 2; h <= len; h <<= 1) {
Comp wn(cos(2 * PI / h), sin(on * 2 * PI / h));
for (int j = 0; j < len; j += h) {
Comp w(1, 0);
for (int k = j; k < j + h / 2; k++) {
Comp u = y[k];
Comp t = w * y[k + h / 2];
y[k] = u + t;
y[k + h / 2] = u - t;
w *= wn;
}
}
}
if (on == -1) {
for (int i = 0; i < len; i++) {
y[i] /= len;
}
}
}
FFT 模板(HDU 1402 - A * B Problem Plus):
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
const double PI = acos(-1.0);
struct Complex {
double x, y;
Complex(double _x = 0.0, double _y = 0.0) : x(_x), y(_y) {}
Complex operator - (const Complex &b) const { return Complex(x - b.x, y - b.y); }
Complex operator + (const Complex &b) const { return Complex(x + b.x, y + b.y); }
Complex operator * (const Complex &b) const { return Complex(x * b.x - y * b.y, x * b.y + y * b.x); }
};
void change(Complex y[], int len) {
int i, j, k;
for (int i = 1, j = len / 2; i < len - 1; i++) {
if (i < j) std::swap(y[i], y[j]);
k = len / 2;
while (j >= k) {
j -= k;
k /= 2;
}
if (j < k) j += k;
}
}
void fft(Complex y[], int len, int on) {
change(y, len);
for (int h = 2; h <= len; h <<= 1) {
Complex wn(cos(2 * PI / h), sin(on * 2 * PI / h));
for (int j = 0; j < len; j += h) {
Complex w(1, 0);
for (int k = j; k < j + h / 2; k++) {
Complex u = y[k];
Complex t = w * y[k + h / 2];
y[k] = u + t;
y[k + h / 2] = u - t;
w *= wn;
}
}
}
if (on == -1) {
for (int i = 0; i < len; i++) {
y[i].x /= len;
}
}
}
const int MAXN = 200020;
Complex x1[MAXN], x2[MAXN];
char str1[MAXN / 2], str2[MAXN / 2];
int sum[MAXN];
int main() {
while (scanf("%s%s", str1, str2) == 2) {
int len1 = strlen(str1);
int len2 = strlen(str2);
int len = 1;
while (len < len1 * 2 || len < len2 * 2) len <<= 1;
for (int i = 0; i < len1; i++) x1[i] = Complex(str1[len1 - 1 - i] - '0', 0);
for (int i = len1; i < len; i++) x1[i] = Complex(0, 0);
for (int i = 0; i < len2; i++) x2[i] = Complex(str2[len2 - 1 - i] - '0', 0);
for (int i = len2; i < len; i++) x2[i] = Complex(0, 0);
fft(x1, len, 1);
fft(x2, len, 1);
for (int i = 0; i < len; i++) x1[i] = x1[i] * x2[i];
fft(x1, len, -1);
for (int i = 0; i < len; i++) sum[i] = int(x1[i].x + 0.5);
for (int i = 0; i < len; i++) {
sum[i + 1] += sum[i] / 10;
sum[i] %= 10;
}
len = len1 + len2 - 1;
while (sum[len] == 0 && len > 0) len--;
for (int i = len; i >= 0; i--) printf("%c", sum[i] + '0');
printf("\n");
}
return 0;
}
参考文献
- 桃酱的算法笔记