FFT & NTT 学习 模板

4 篇文章 0 订阅
3 篇文章 0 订阅

参考资料:

《算法导论》(CLRS)第 30 章:多项式与快速傅里叶变换

http://www.gatevin.moe/acm/fft%E7%AE%97%E6%B3%95%E5%AD%A6%E4%B9%A0%E7%AC%94%E8%AE%B0/

http://blog.csdn.net/acdreamers/article/details/39026505


先来两道裸的模板题(大整数乘法):


HDU 1402(FFT,复数域实现;代码中的函数名为 dft)


Code:

#include <bits/stdc++.h>
#define PI acos(-1.0)
#define maxn 140010
#define maxl 50010
using namespace std;
// Larger of two ints (either operand when they are equal).
inline int max(int a, int b) { return b < a ? a : b; }
// Minimal complex-number type used by the FFT code in this program.
struct Complex {
    double real, image;  // real and imaginary parts
    Complex(double _real = 0, double _image = 0) : real(_real), image(_image) {}
};
// Component-wise addition.
Complex operator + (const Complex& c1, const Complex& c2) { return Complex(c1.real + c2.real, c1.image + c2.image); }
// Component-wise subtraction.
Complex operator - (const Complex& c1, const Complex& c2) { return Complex(c1.real - c2.real, c1.image - c2.image); }
// (a + bi)(c + di) = (ac - bd) + (ad + cb)i
Complex operator * (const Complex& c1, const Complex& c2) { return Complex(c1.real * c2.real - c1.image * c2.image, c1.real * c2.image + c2.real * c1.image); }
// Debug printer: writes "<real>+ i * <image>" and a newline (flushing).
// Must return the stream: falling off the end of a function declared to
// return std::ostream& is undefined behavior, and chained `<<` needs it.
std::ostream& operator << (std::ostream& out, const Complex& c) {
    out << c.real << "+ i * " << c.image << std::endl;
    return out;
}
// Operand polynomials: decimal digits stored least-significant first.
Complex a[maxn];
Complex b[maxn];
// Scratch buffer used by dft() for the bit-reversed working copy.
Complex A[maxn];
// The two decimal input strings.
char s1[maxl];
char s2[maxl];
// Product digits after rounding and carry propagation.
int ans[maxn];
// Returns the low `len` bits of x in reversed order
// (e.g. rev(0b110, 3) == 0b011).
int rev(int x, int len) {
    int out = 0;
    for (int bit = 0; bit < len; ++bit)
        out = (out << 1) | ((x >> bit) & 1);
    return out;
}
// Iterative radix-2 FFT (bit-reversal permutation followed by butterfly
// passes of doubling block size).
//   a   - len values; overwritten with the transform
//   len - transform length (callers in this program pass a power of two)
//   D   - 1 for the forward transform, -1 for the inverse; the inverse
//         result is divided by len
// Works in the global scratch buffer A and copies the result back into a.
void dft(Complex* a, int len, int D) {
    const int levels = (int)((double)log(len) / log(2) + 0.5);

    // Place the input in bit-reversed order.
    for (int idx = 0; idx < len; ++idx)
        A[rev(idx, levels)] = a[idx];

    // Butterfly passes.
    for (int level = 1; level <= levels; ++level) {
        const int block = 1 << level;
        const int half = block >> 1;
        const Complex step(cos(D * 2 * PI / block), sin(D * 2 * PI / block));
        for (int base = 0; base < len; base += block) {
            Complex twiddle(1, 0);
            for (int off = 0; off < half; ++off) {
                Complex& lo = A[base + off];
                Complex& hi = A[base + off + half];
                const Complex t = twiddle * hi;
                const Complex u = lo;
                lo = u + t;
                hi = u - t;
                twiddle = twiddle * step;
            }
        }
    }

    // Inverse transform: scale every entry by 1/len.
    if (D == -1) {
        for (int idx = 0; idx < len; ++idx) {
            A[idx].real /= len;
            A[idx].image /= len;
        }
    }

    for (int idx = 0; idx < len; ++idx) a[idx] = A[idx];
}
// Multiplies the decimal numbers held in s1 and s2 using FFT polynomial
// multiplication and prints the product (no leading zeros; "0" if zero).
void work() {
    const int len1 = strlen(s1);
    const int len2 = strlen(s2);

    // Transform size: smallest power of two >= len1 + len2.
    int size = 1;
    while (size < len1 + len2) size <<= 1;

    // Digits become polynomial coefficients, least significant first.
    memset(a, 0, sizeof(a));
    memset(b, 0, sizeof(b));
    for (int k = 0; k < len1; ++k) a[k] = Complex(s1[len1 - 1 - k] - '0', 0);
    for (int k = 0; k < len2; ++k) b[k] = Complex(s2[len2 - 1 - k] - '0', 0);

    // Convolution = pointwise product in the frequency domain.
    dft(a, size, 1);
    dft(b, size, 1);
    for (int k = 0; k < size; ++k) a[k] = a[k] * b[k];
    dft(a, size, -1);

    // Round coefficients to integers, then propagate carries.
    memset(ans, 0, sizeof(ans));
    for (int k = 0; k < size; ++k) ans[k] = (int)(a[k].real + 0.5);
    for (int k = 0; k < size; ++k) {
        const int carry = ans[k] / 10;
        ans[k] -= carry * 10;
        ans[k + 1] += carry;
    }
    if (ans[size]) ++size;

    // Print from the highest non-zero digit.
    int top = size - 1;
    while (top >= 0 && !ans[top]) --top;
    if (top == -1) { printf("0\n"); return; }
    while (top >= 0) printf("%d", ans[top--]);
    printf("\n");
}
// Reads pairs of decimal integers from stdin and prints each product.
int main() {
#ifdef LOCAL
    // Local debugging only. An unconditional freopen on a judge without
    // in.txt fails and closes stdin, so every scanf then fails.
    freopen("in.txt", "r", stdin);
#endif
    // Require both operands; a lone trailing token is not multiplied
    // against a stale s2.
    while (scanf("%s%s", s1, s2) == 2) work();
    return 0;
}


51nod 1028(NTT,模素数 P 意义下的快速数论变换)


Code:

#include <bits/stdc++.h>
#define maxn 300010
typedef long long LL;
using namespace std;
// Operand digit arrays (least significant digit first).
LL a[maxn];
LL b[maxn];
// Scratch buffer used by ntt().
LL A[maxn];
// wn[i] = G^((P-1)/2^i) mod P, filled by pre(); only NUM entries are set.
LL wn[22];
// Decimal input strings.
char s1[maxn];
char s2[maxn];
const int N = 1 << 18;          // not referenced (ntt's parameter N shadows it)
const int P = (479 << 21) + 1;  // 479 * 2^21 + 1 = 1004535809, the NTT modulus
const int G = 3;                // generator used to build the roots in wn[]
const int NUM = 20;             // number of wn[] entries computed by pre()
// Square-and-multiply modular exponentiation: returns a^b mod P.
// Callers pass a < P so that base * base fits in a long long.
LL poww(LL a, LL b) {
    LL result = 1;
    for (LL base = a, exp = b; exp != 0; exp >>= 1) {
        if (exp & 1) result = result * base % P;
        base = base * base % P;
    }
    return result;
}
LL pre() {
    for (int i = 0; i < NUM; ++i) {
        int t = 1 << i;
        wn[i] = poww(G, (P - 1) / t);
    }
}
// Writes the digits of decimal string s into a[], least significant first
// (a[0] is the last character), and returns the digit count.
int geta(char* s, LL* a) {
    const int n = strlen(s);
    for (int src = n - 1, dst = 0; src >= 0; --src, ++dst)
        a[dst] = s[src] - '0';
    return n;
}
// Bit-reverses x over log2(len) bits, i.e. as many bits as there are powers
// of two 2, 4, 8, ... not exceeding len (rev(0b110, 8) == 0b011).
int rev(int x, int len) {
    int out = 0;
    for (int width = 2; width <= len; width <<= 1) {
        out = (out << 1) | (x & 1);
        x >>= 1;
    }
    return out;
}
// In-place iterative number-theoretic transform modulo P.
//   a   - len values in [0, P); overwritten with the transform
//   len - transform length (callers pass a power of two)
//   N   - 1 for the forward transform, -1 for the inverse (scaled by len^-1)
// Uses the global scratch buffer A and the root table wn[] built by pre().
void ntt(LL* a, int len, int N) {
    // wn[] only has NUM initialized roots; a longer transform would silently
    // use unset (zero) entries.
    assert(len <= (1 << (NUM - 1)));

    // Bit-reversal permutation. For a power-of-two len every A[0..len) is
    // written here, so clearing the whole (maxn-sized) buffer is unnecessary.
    for (int i = 0; i < len; ++i) A[rev(i, len)] = a[i];

    int id = 0;
    for (int m = 2; m <= len; m <<= 1) {
        ++id;  // m == 1 << id, so wn[id] is the matching root of unity
        const int half = m >> 1;
        for (int j = 0; j < len; j += m) {
            LL w = 1;
            for (int k = 0; k < half; ++k) {
                const LL temp = w * A[j + k + half] % P;
                const LL ori = A[j + k] % P;
                A[j + k] = (ori + temp) % P;
                A[j + k + half] = (ori - temp + P) % P;
                w = w * wn[id] % P;
            }
        }
    }

    if (N == -1) {
        // Inverse via the forward transform: reversing entries 1..len-1
        // swaps each root for its inverse; then multiply by len^-1 (Fermat).
        for (int i = 1, j = len - 1; i < j; ++i, --j) swap(A[i], A[j]);
        const LL inv = poww(len, P - 2);
        for (int i = 0; i < len; ++i) A[i] = A[i] * inv % P;
    }
    for (int i = 0; i < len; ++i) a[i] = A[i];
}
void work() {
    memset(a, 0, sizeof(a)); memset(b, 0, sizeof(b));
    int len1 = geta(s1, a), len2 = geta(s2, b);
    int len = len1 + len2, len0 = 1;
    while (len0 < len) len0 <<= 1;

    ntt(a, len0, 1); ntt(b, len0, 1);
    for (int i = 0; i < len0; ++i) a[i] = a[i] * b[i] % P;
    ntt(a, len0, -1);

    for (int i = 0; i < len; ++i) {
        a[i + 1] += a[i] / 10;
        a[i] %= 10;
    }
    if (a[len]) ++len;

    int p = len - 1;
    while (p >= 0 && a[p] == 0) --p;
    if (p == -1) { printf("0\n"); return; }
    for (int i = p; i >= 0; --i) printf("%lld", a[i]); printf("\n");
}
// Builds the root table once, then multiplies pairs of decimal integers read
// from stdin.
int main() {
    pre();
    // Require both operands; a lone trailing token is not multiplied
    // against a stale s2.
    while (scanf("%s%s", s1, s2) == 2) work();
    return 0;
}



  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值