之前向学习了一个 $\text{FFT}$ 的优化,但是像我这么弱的人每次打 $\text{FFT}$ 板子的时候都会忘记这个东西,在这里记一下。
我们知道普通的 $\text{FFT}$ 会用到原根 $\omega_n^0,\omega_n^1\cdots\omega_n^{n-1}$ 然后这些东西会在枚举步长的时候通过 $\omega_n = e^{\frac{2\pi}{n}}$ 和 $e^{\theta i} = \cos \theta + i\sin \theta$ 这两个公式一次一次算出来。
然而我们知道,调用三角函数是非常慢的,每次计算的时候,即使你是手写的 $\text{complex}$ 也会非常慢,这就使得这种 $\text{FFT}$ 的常数巨大无比。
所以我们就预处理一下每次需要用到的 $\omega$ ,把每一种步长需要用到的 $\omega$ 扔到同一个数组 $W$ 里,有每种步长的 $\omega$ 连续。而因为 $\sum_{i=0}^{n} 2^i = 2^{i + 1} - 1$ ,所以每次需要访问步长为 $s$ 的 $\omega$ 时候只要访问 $W[s]$ 就可以了,将一个指针指向他,而后面的只要把指针一步一步往后移即可。
这是 $\text{DFT}$ 的时候用的,但是我们知道 $\text{IDFT}$ 的时候用的 $\omega$ 和 $\text{DFT}$ 的时候是不一样的。
然而我们不需要重新处理 $\text{IDFT}$ 用的 $\omega$ ,只需要把需要 $\text{FFT}$ 的 $A$ 从 $1$ 到 $n - 1$ 的值 $\text{reverse}$ 一下就行了。原理是本来 $\text{IDFT}$ 的时候需要把 $\omega$ 翻过来,但是那个有点麻烦,于是我们就把 $A$ 给翻过来就行了。由于 $\text{FFT}$ 可以被理解为一个特殊的矩阵乘法,所以你顺着搞下来和反着搞回去最后的结果是一样的,所以它是对的。
然后下面贴了一道水题的代码来帮助理解:
例:求有多少个从 $1,2,\cdots,n$ 中取三个元素的排列 $(a,b,c)$ 满足 $x_a=x_b-x_c$。由于是排列,所以 $(a,b,c)$ 与 $(c,b,a)$ 视为两组解。
#include <algorithm>
#include <cstdio>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <iostream>
#include <queue>
#include <set>
#include <stack>
#define R register
#define ll long long
#define db double
#define ld long double
#define sqr(_x) (_x) * (_x)
#define Cmax(_a, _b) ((_a) < (_b) ? (_a) = (_b), 1 : 0)
#define Cmin(_a, _b) ((_a) > (_b) ? (_a) = (_b), 1 : 0)
#define Max(_a, _b) ((_a) > (_b) ? (_a) : (_b))
#define Min(_a, _b) ((_a) < (_b) ? (_a) : (_b))
#define Abs(_x) (_x < 0 ? (-(_x)) : (_x))
using namespace std;
namespace Dntcry
{
inline int read()
{
R int a = 0, b = 1; R char c = getchar();
for(; c < '0' || c > '9'; c = getchar()) (c == '-') ? b = -1 : 0;
for(; c >= '0' && c <= '9'; c = getchar()) a = (a << 1) + (a << 3) + c - '0';
return a * b;
}
inline ll lread()
{
R ll a = 0, b = 1; R char c = getchar();
for(; c < '0' || c > '9'; c = getchar()) (c == '-') ? b = -1 : 0;
for(; c >= '0' && c <= '9'; c = getchar()) a = (a << 1) + (a << 3) + c - '0';
return a * b;
}
const int Maxn = 1000010, Maxl = 600010, lim = 100000;
const ld pi = acos(-1);
struct Complex
{
ld real, imag;
Complex operator + (const Complex &b) const
{
return (Complex) {real + b.real, imag + b.imag};
}
Complex operator - (const Complex &b) const
{
return (Complex) {real - b.real, imag - b.imag};
}
Complex operator * (const Complex &b) const
{
return (Complex) {real * b.real - imag * b.imag, b.real * imag + real * b.imag};
}
}C[Maxl], A[Maxl], w[Maxl], wl;
int n, m, x[Maxn], Cnt[Maxl], len, bit, rev[Maxl], zero;
ll Ans[Maxn], Sum;
void Get_Rev(R int bit)
{
for(R int i = 0; i < len; i++)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << bit - 1);
return ;
}
void FFT(R Complex *K, R ld DFT)
{
for(R int i = 0; i < len; i++) if(i < rev[i]) swap(K[i], K[rev[i]])
R Complex *W;
for(R int i = 2; i <= len; i <<= 1)
{
for(R int j = 0, step = i >> 1; j < len; j += i)
{
W = w + step;
for(R int k = j; k < j + step; W++, k++)
{
R Complex G = K[k], H = *W * K[k + step];
K[k] = G + H;
K[k + step] = G - H;
}
}
}
if(DFT == -1.0)
for(R int i = 0; i < len; i++)
K[i].real /= 1.0 * len, K[i].imag /= 1.0 * len;
return ;
}
int Main()
{
n = read();
for(R int i = 1; i <= n; i++)
{
x[i] = read(); if(!x[i]) zero++;
x[i] += lim, m = Max(m, x[i]);
Cnt[x[i]]++;
} m++;
for(bit = 0, len = 1; (1 << bit) < (m << 1); bit++) len <<= 1;
R int tmp = len >> 1;
w[tmp] = (Complex) {1.0, 0.0};
wl = w[++tmp] = (Complex) {cos(2.0 * pi / len), sin(2.0 * pi / len)};
for(tmp++; tmp < len; tmp++) w[tmp] = w[tmp - 1] * wl;
for(R int i = (len >> 1) - 1; i; i--) w[i] = w[i << 1];
Get_Rev(bit);
for(R int i = 0; i < m; i++) A[i] = (Complex) {1.0 * Cnt[i], 0.0};
FFT(A, 1.0);
C[0] = A[0] * A[0];
for(R int i = 1; i < len; i++) C[i] = A[len - i] * A[len - i];
FFT(C, -1.0);
for(R int i = 0; i < len; i++) Ans[i] = (ll)(C[i].real + 0.5);
for(R int i = 1; i <= n; i++) Ans[x[i] << 1]--;
for(R int i = 1; i <= n; i++) Sum += Ans[x[i] + lim];
Sum -= 2ll * zero * (n - 1);
printf("%lld\n", Sum);
return 0;
}
}
int main()
{
return Dntcry :: Main();
}