快速傅里叶变换(FFT)可以将多项式相乘的时间复杂度从朴素算法的O(n^2)优化到O(nlgn),详细过程参考算法导论.
FFT的流程大致是:
1):构造多项式,复杂度O(n)
2):求两个多项式的DFT,复杂度O(nlgn)
3):构造多项式乘积的点值表达式,复杂度O(n)
4):求点值表达式的IDFT,复杂度O(nlgn).
下面是两道最简单的习题:
HDU 1402:点击打开链接
求两个大数乘积.
因为一个大数可以看成是一个多项式,每一位上的值都表示对应次数下的系数,因此可以用FFT加速.
本题的一个坑点就是
len = l1+l2-1;
这句代码:可能是由于精度问题,在比len更高的位上会出现非0值,所以要从这一位开始向低位查找最高的非0位.
#include <bits/stdc++.h>
using namespace std;
// pi expands to a call of acos(-1) at every place it is used.
#define pi acos (-1)
// Element count of the global character, FFT and result arrays below.
#define maxn 200010
// Minimal complex number for the FFT: x is the real part, y the imaginary part.
struct plex {
    double x, y;

    // Both parts default to zero.
    plex (double re = 0.0, double im = 0.0) : x (re), y (im) {}

    // Component-wise sum.
    plex operator + (const plex &rhs) const {
        const double re = x+rhs.x;
        const double im = y+rhs.y;
        return plex (re, im);
    }

    // Component-wise difference.
    plex operator - (const plex &rhs) const {
        const double re = x-rhs.x;
        const double im = y-rhs.y;
        return plex (re, im);
    }

    // Complex product: (x + yi)(rhs.x + rhs.y i).
    plex operator * (const plex &rhs) const {
        const double re = x*rhs.x-y*rhs.y;
        const double im = x*rhs.y+y*rhs.x;
        return plex (re, im);
    }
};
// Permutes y[0..len) in place so that each element moves to the index obtained
// by reversing the bits of its old index (len is expected to be a power of two).
// rev tracks the bit-reversed counterpart of idx and is advanced by a
// "reversed increment": clear leading set bits from the top, then set the
// first clear one.
void change (plex *y, int len) {
    int rev = len / 2;
    for (int idx = 1; idx < len - 1; ++idx) {
        if (idx < rev)
            std::swap (y[idx], y[rev]);
        int bit = len / 2;
        for (; rev >= bit; bit /= 2)
            rev -= bit;
        // The loop above only exits once rev < bit, so the add is unconditional.
        rev += bit;
    }
}
// In-place iterative radix-2 FFT over y[0..len); len must be a power of two.
//   on ==  1 : transform using the root exp(-2*pi*i/h) at each stage.
//   on == -1 : the opposite root, followed by division of every element by len,
//              so that fft(y, len, 1) followed by fft(y, len, -1) restores y.
// Both the real and the imaginary part are normalised on the inverse pass;
// previously only the real part was, leaving y[i].y scaled by len.
void fft(plex y[],int len,int on)
{
    change(y,len);                       // bit-reversal reordering
    for(int h = 2; h <= len; h <<= 1)    // h = size of the butterflies being merged
    {
        const int half = h/2;
        const double angle = -on*2*pi/h;
        const plex wn(cos(angle),sin(angle));
        for(int j = 0;j < len;j+=h)
        {
            plex w(1,0);                 // running power of wn
            for(int k = j;k < j+half;k++)
            {
                const plex u = y[k];
                const plex t = w*y[k+half];
                y[k] = u+t;
                y[k+half] = u-t;
                w = w*wn;
            }
        }
    }
    if(on == -1)
    {
        for(int i = 0;i < len;i++)
        {
            y[i].x /= len;
            y[i].y /= len;
        }
    }
}
// The two input numbers as decimal strings, filled by scanf in main.
char a[maxn], b[maxn];
// FFT work buffers: digits of a and b, least significant first; x1 also receives the product.
plex x1[maxn], x2[maxn];
// Digits of the result, least significant first.
int ans[maxn];
int main () {
while (scanf ("%s%s", a, b) == 2) {
int len = 2, l1 = strlen (a), l2 = strlen (b);
while (len < l1*2 || len < l2*2)
len <<= 1;
for (int i = 0; i < l1; i++) {
x1[i] = plex (a[l1-1-i]-'0', 0);
}
for (int i = l1; i < len; i++)
x1[i] = plex (0, 0);
for (int i = 0; i < l2; i++) {
x2[i] = plex (b[l2-1-i]-'0', 0);
}
for (int i = l2; i < len; i++)
x2[i] = plex (0, 0);
fft (x1, len, 1);
fft (x2, len, 1);
for (int i = 0; i < len; i++)
x1[i] = x1[i]*x2[i];
fft (x1, len, -1);
for (int i = 0; i < len; i++) {
ans[i] = (int)(x1[i].x+0.5);
}
for (int i = 0; i < len; i++) {
if (ans[i] >= 10) {
ans[i+1] += ans[i]/10;
ans[i] %= 10;
}
}
len = l1+l2-1;
while (ans[len] <= 0 && len > 0)
len--;
for (int i = len; i >= 0; i--) {
printf ("%d", ans[i]);
}
printf ("\n");
}
return 0;
}
HDU 4609: 点击打开链接
题意是给出n个长度,任取3个求能组成三角形的概率.
首先记录下每个长度的数量,然后用FFT加速求出任取两个长度下的情况,这里面有重复:
首先减去两次都取同一根的情况,减完之后的结果都/2.
最后用所有取法的总数减去不能组成三角形的取法数:将最初的长度序列从小到大排序后枚举下标,假设当前这条边是最长边,那么另外两条边长度之和小于等于这条边的取法都应该减去,这里用前缀和统计一下就好了.
#include <bits/stdc++.h>
using namespace std;
// pi expands to a call of acos(-1) at every place it is used.
#define pi acos (-1)
// Element count of the global count, prefix-sum, length and FFT arrays below.
#define maxn 611111
// Plain complex value used by the FFT: x holds the real part, y the imaginary part.
struct plex {
    double x, y;

    // Constructs re + im*i; both arguments default to zero.
    plex (double re = 0.0, double im = 0.0) : x (re), y (im) {}

    // Sum of two complex values.
    plex operator + (const plex &other) const {
        plex out (x+other.x, y+other.y);
        return out;
    }

    // Difference of two complex values.
    plex operator - (const plex &other) const {
        plex out (x-other.x, y-other.y);
        return out;
    }

    // Product of two complex values.
    plex operator * (const plex &other) const {
        plex out (x*other.x-y*other.y, x*other.y+y*other.x);
        return out;
    }
};
// Reorders y[0..len) in place into bit-reversed index order; len must be a
// power of two (len <= 1 is a no-op).
//
// The earlier version split the array recursively into even- and odd-indexed
// halves using variable-length arrays of plex on the stack. That recursion
// produces exactly the bit-reversal permutation, but variable-length arrays
// are not standard C++ and the nested temporaries need about 2*len plex
// objects of stack (roughly 8 MB at len = 2^18). This version performs the
// same permutation with swaps and no extra storage.
void change (plex y[], int len) {
    // j is kept equal to the bit-reversal of i.
    for (int i = 1, j = 0; i < len; i++) {
        // Add one to j "from the top": clear leading set bits, set the first clear one.
        int bit = len >> 1;
        for (; j & bit; bit >>= 1)
            j ^= bit;
        j ^= bit;
        // Swap each pair once; bit reversal is its own inverse.
        if (i < j)
            std::swap (y[i], y[j]);
    }
}
// In-place iterative radix-2 FFT over y[0..len); len must be a power of two.
//   on ==  1 : transform using the root exp(+2*pi*i/h) at each stage.
//   on == -1 : the opposite root, followed by division of every element by len,
//              so that fft(y, len, 1) followed by fft(y, len, -1) restores y.
// Both the real and the imaginary part are normalised on the inverse pass;
// previously only the real part was, leaving y[i].y scaled by len.
void fft(plex y[],int len,int on)
{
    change(y,len);                       // bit-reversal reordering
    for(int h = 2; h <= len; h <<= 1)    // h = size of the butterflies being merged
    {
        const int half = h/2;
        const double angle = on*2*pi/h;
        const plex wn(cos(angle),sin(angle));
        for(int j = 0;j < len;j+=h)
        {
            plex w(1,0);                 // running power of wn
            for(int k = j;k < j+half;k++)
            {
                const plex u = y[k];
                const plex t = w*y[k+half];
                y[k] = u+t;
                y[k+half] = u-t;
                w = w*wn;
            }
        }
    }
    if(on == -1)
    {
        for(int i = 0;i < len;i++)
        {
            y[i].x /= len;
            y[i].y /= len;
        }
    }
}
// num: first the number of sticks of each length, later the number of unordered
// pairs of distinct sticks for each length sum; sum: prefix sums of num.
long long num[maxn], sum[maxn];
// Stick lengths of the current test case, stored from index 1.
int a[maxn];
// FFT work buffer.
plex x[maxn];
// Number of sticks in the current test case.
long long n;
// For each test case reads n stick lengths and prints, with 7 decimals, the
// fraction of 3-stick choices that can form a triangle.
int main () {
    int t;
    scanf ("%d", &t);
    while (t-- != 0) {
        scanf ("%lld", &n);

        // Histogram of lengths, and one past the largest length.
        memset (num, 0, sizeof num);
        long long Max = 0;
        for (int i = 1; i <= n; i++) {
            scanf ("%d", &a[i]);
            ++num[a[i]];
            if (a[i] > Max)
                Max = a[i];
        }
        ++Max;

        // Power-of-two transform size, at least 2 and at least 2*Max.
        int len = 2;
        while (len < Max*2)
            len *= 2;

        // Square the histogram polynomial: coefficient s counts ordered pairs with sum s.
        for (int i = 0; i < len; i++)
            x[i] = plex (num[i], 0);
        fft (x, len, 1);
        for (int i = 0; i < len; i++)
            x[i] = x[i]*x[i];
        fft (x, len, -1);
        for (int i = 0; i < len; i++)
            num[i] = static_cast<long long> (x[i].x+0.5);

        // Drop the pairs that use the same stick twice.
        for (int i = 1; i <= n; i++)
            --num[a[i]+a[i]];

        // Each remaining pair was counted in both orders; halve, then build prefix sums.
        num[0] /= 2;
        sum[0] = 0;
        for (int i = 1; i < len; i++) {
            num[i] /= 2;
            sum[i] = sum[i-1]+num[i];
        }

        sort (a+1, a+1+n);

        // Treating a[i] as the longest side, every pair whose sum is <= a[i] fails.
        const long long tot = n*(n-1)*(n-2)/6;
        long long bad = 0;
        for (int i = 3; i <= n; i++)
            bad += sum[a[i]];
        const long long ans = tot - bad;
        printf ("%.7f\n", ans*1.0/tot);
    }
    return 0;
}
但是FFT有一个很致命的弱点就是会产生精度误差,在换成long double都不行的时候就需要用到NTT。
NTT就是用模意义下的原根(的幂)代替FFT中的单位复根,其余代码基本相同。
求原根的代码:
#include <cstdio>
#include <cmath>
#include <algorithm>
#include <iostream>
#include <vector>
#include <cstring>
using namespace std;
// Modulus whose primitive root is searched for; read from stdin in main.
int P;
// Exclusive upper bound of the sieve.
const int NUM = 32170;
// Primes found by getPrime, in increasing order.
int prime[NUM/4];
// Sieve marks: set for every number getPrime has handled (primes included).
bool f[NUM];
// Number of valid entries in prime[].
int pNum = 0;
// Linear sieve: fills prime[0..pNum) with all primes below NUM, ascending.
// f[v] is set for every visited value (the primes themselves included).
void getPrime () {
    for (int n = 2; n < NUM; ++n) {
        if (!f[n]) {
            f[n] = 1;
            prime[pNum++] = n;
        }
        for (int idx = 0; idx < pNum; ++idx) {
            const int multiple = n*prime[idx];
            if (multiple >= NUM)
                break;
            f[multiple] = 1;
            // Stop at the smallest prime factor of n so each composite is marked once.
            if (n%prime[idx] == 0)
                break;
        }
    }
}
// Modular exponentiation by repeated squaring: returns a^b mod P.
// b is consumed bit by bit from the least significant end.
long long getProduct(int a,int b,int P) {
    long long result = 1;
    for (long long base = a; b; b >>= 1) {
        if (b&1)
            result = result*base%P;
        base = base*base%P;
    }
    return result;
}
// Returns true when num^((P-1)/q) mod P differs from 1 for every distinct
// prime factor q of P-1, i.e. when num passes the primitive-root test modulo
// the global P. Assumes P is prime and num is not a multiple of P (neither is
// checked here). Uses the prime table built by getPrime().
//
// Fixes over the earlier version:
//  * when the remaining cofactor k was prime, prime[i] was recorded instead of
//    k, so the large prime factor was never tested (e.g. P = 41 accepted 3,
//    whose order is 8) and a non-divisor of P-1 could be used as a divisor;
//  * a cofactor left over after the table was exhausted was silently dropped;
//  * P == 1 (k == 0) looped forever.
bool judge (int num) {
    int elem[1000];      // distinct prime factors of P-1
    int elemNum = 0;
    int k = P - 1;       // part of P-1 not factored yet

    // Strip the tabulated primes, smallest first.
    for (int i = 0; i < pNum && k > 1; ++i) {
        const int p = prime[i];
        if (k/p < p)     // p*p > k: whatever is left of k is itself prime
            break;
        if (k%p != 0)
            continue;
        elem[elemNum++] = p;
        while (k%p == 0)
            k /= p;
    }

    // Past the end of the table, continue with plain trial division so that a
    // cofactor made of two primes larger than the table is still split.
    for (int d = (pNum > 0 ? prime[pNum-1] + 1 : 2); k > 1 && d <= k/d; ++d) {
        if (k%d != 0)
            continue;
        elem[elemNum++] = d;
        while (k%d == 0)
            k /= d;
    }

    // What remains is a single prime factor.
    if (k > 1)
        elem[elemNum++] = k;

    for (int i = 0; i < elemNum; ++i) {
        if (getProduct (num, (P-1)/elem[i], P) == 1)
            return false;
    }
    return true;
}
// Builds the prime table once, then for every modulus P read from stdin prints
// the smallest integer >= 2 that judge() accepts.
int main()
{
    getPrime();
    while (cin >> P)
    {
        int candidate = 2;
        while (!judge(candidate))
            ++candidate;
        cout << candidate<< endl;
    }
    return 0;
}
HDU 1402:
随便选一个不太大的模数和他的原根就好了。
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <iostream>
#include <cmath>
#include <map>
#include <vector>
#include <stack>
using namespace std;
// Modulus for all NTT arithmetic; equals 479 * 2^21 + 1.
#define mod 1004535809LL
// Base from which ntt derives its roots of unity (the article picks it as a primitive root of mod).
#define G 3
// Element count of the global character and number arrays below.
#define maxn 400005
// Returns a^b modulo mod using binary exponentiation.
// a is squared once per bit of b; the result picks up a for every set bit.
long long qpow (long long a, long long b) {
    long long result = 1;
    for (; b; b >>= 1) {
        if (b&1)
            result = (result*a)%mod;
        a = (a*a)%mod;
    }
    return result;
}
// Permutes y[0..len) in place into bit-reversed index order (len is expected
// to be a power of two). rev always holds the bit-reversed counterpart of idx.
void change (long long y[], int len) {
    int rev = len / 2;
    for (int idx = 1; idx < len - 1; ++idx) {
        if (idx < rev)
            std::swap(y[idx], y[rev]);
        // Reversed increment: drop leading set bits, then set the first clear one.
        int bit = len / 2;
        for (; rev >= bit; bit /= 2)
            rev -= bit;
        // The loop above only exits once rev < bit, so the add is unconditional.
        rev += bit;
    }
}
// In-place number-theoretic transform of y[0..len) modulo mod; len must be a
// power of two. on == 1 uses the roots G^((mod-1)/span); on == -1 uses their
// modular inverses and finally multiplies every element by the inverse of len.
void ntt(long long y[], int len, int on) {
    change (y, len);                             // bit-reversal reordering
    for (int span = 2; span <= len; span *= 2) {
        const int half = span / 2;
        long long root = qpow(G, (mod-1)/span);
        if (on == -1)
            root = qpow(root, mod-2);            // inverse via Fermat's little theorem
        for (int base = 0; base < len; base += span) {
            long long twiddle = 1;               // running power of root
            for (int off = 0; off < half; ++off) {
                long long &lo = y[base + off];
                long long &hi = y[base + off + half];
                const long long even = lo;
                const long long odd = twiddle * hi % mod;
                lo = (even + odd) % mod;
                hi = (even - odd + mod) % mod;   // + mod keeps the value non-negative
                twiddle = twiddle * root % mod;
            }
        }
    }
    if (on != -1)
        return;
    const long long invLen = qpow (len, mod - 2);
    for (int i = 0; i < len; i++)
        y[i] = y[i] * invLen % mod;
}
// The two input numbers as decimal strings, filled by scanf in main.
char a[maxn], b[maxn];
// NTT work buffers: digits of a and b, least significant first; x1 also receives the product.
long long x1[maxn], x2[maxn];
// Digits of the result, least significant first.
long long ans[maxn];
int main () {
while (scanf ("%s%s", a, b) == 2) {
int len = 2, l1 = strlen (a), l2 = strlen (b);
while (len < l1*2 || len < l2*2)
len <<= 1;
//cout << len << endl;
for (int i = 0; i < l1; i++) {
x1[i] = a[l1-1-i]-'0';
}
for (int i = l1; i < len; i++)
x1[i] = 0;
for (int i = 0; i < l2; i++) {
x2[i] = b[l2-1-i]-'0';
}
for (int i = l2; i < len; i++)
x2[i] = 0;
ntt(x1, len, 1);
ntt(x2, len, 1);
for (int i = 0; i < len; i++)
x1[i] = x1[i]*x2[i]%mod;
ntt(x1, len, -1);
for (int i = 0; i < len; i++) {
ans[i] = x1[i];
}
for (int i = 0; i < len; i++) {
if (ans[i] >= 10) {
ans[i+1] += ans[i]/10;
ans[i] %= 10;
}
}
len = l1+l2-1;
while (ans[len] <= 0 && len > 0)
len--;
for (int i = len; i >= 0; i--) {
printf ("%lld", ans[i]);
}
printf ("\n");
}
return 0;
}