快速傅里叶变换(FFT)与快速数论变换(NTT)+例题

多项式的系数表示法

考虑多项式 A(x)=i=0naixi A ( x ) = ∑ i = 0 n a i x i ,其中 {a0,a1,,an} { a 0 , a 1 , … , a n } 被称为多项式 A(x) A ( x ) 的系数向量。每个多项式都有唯一的系数向量,每个系数向量都对应唯一的多项式。

多项式的点值表示法

我们可以把多项式 A(x) A ( x ) 看做是一个 n n 次函数,我们可以取n+1个不同的值 b0,b1,,bn b 0 , b 1 , ⋯ , b n 带入分别求出 n+1 n + 1 个多项式的值 c0,c1,,cn c 0 , c 1 , ⋯ , c n 。可以看出,从系数表示法到点值表示法是唯一的,而点值表示法在系数未知的时候可以看做是一个 n+1 n + 1 元一次方程组,可以解出唯一系数。因此点值表示与多项式也一一对应。

多项式乘法

C(x)=A(x)B(x)=i=0nj=0maibjxi+j C ( x ) = A ( x ) B ( x ) = ∑ i = 0 n ∑ j = 0 m a i b j x i + j ,其中 A(x),B(x) A ( x ) , B ( x ) 分别是 n,m n , m 次多项式, A(x),B(x) A ( x ) , B ( x ) 的系数向量是 a,b a → , b →
容易发现,用系数表示法使两个向量相乘是 O(n2) O ( n 2 ) 的复杂度,那如何才能优化呢?
考虑两个点值表示的多项式相乘,易发现此时只需要把两个多项式的对应点值相乘即可,复杂度为 O(n) O ( n ) 。但是如何把系数表示转化为点值表示,再转化回来呢?。
如果我们选取 n+1 n + 1 个值暴力代入,复杂度仍然为 O(n2) O ( n 2 ) ,甚至转化回来的时候会用到 O(n3) O ( n 3 ) 的高斯消元,难道点值表示就没有任何可取之处了吗?
因此,一种算法叫做“快速傅里叶变换”诞生了,它可以在 O(nlogn) O ( n l o g n ) 的时间内完成上述两部转化。

快速傅里叶变换(FFT)

单位根

xn=1 x n = 1 ,则 x x 被称为n次单位根。 n n 次单位根共有n个,分别形如 e2kπin,0k<n,kZ e 2 k π i n , 0 ≤ k < n , k ∈ Z ,注意这里的 i i 是虚数单位。为什么呢?

(e2kπin)n=e2kπi=(e2πi)k=1k=1

倒数第二步使用了欧拉公式 exi=cos(x)+isin(x) e x i = c o s ( x ) + i · s i n ( x ) ,因此我们也可以得到 e2kπin=cos(2kπn)+isin(2kπn) e 2 k π i n = c o s ( 2 k π n ) + i · s i n ( 2 k π n ) ,这样我们就可以用平常的复数表示去计算单位根了。
为了方便,我们记 ωn=e2πin ω n = e 2 π i n ,则 n n 个单位根分别为ωn0,ωn1,,ωnn1

单位根的性质

在讨论性质时,均假定 n n 为偶数。

引理1

单位根具有对称性,即ωnk=ωnk+n2。这个定理是比较好证明的,因为有

ωn2n=cos(π)+isin(π)=1ωk+n2n=ωn2nωkn=ωkn ω n n 2 = c o s ( π ) + i · s i n ( π ) = − 1 , ∴ − ω n k + n 2 = − ω n n 2 · ω n k = ω n k

引理2

ω2kn=ωkn20k<n2kZ ω n 2 k = ω n 2 k , 0 ≤ k < n 2 , k ∈ Z

这个利用性质也是很好证明的。
ω2kn=e4kπin=e2kπin2=ωkn2 ω n 2 k = e 4 k π i n = e 2 k π i n 2 = ω n 2 k

FFT算法

上面我们说了那么多,究竟是要干什么呢?没错!把单位根当做数值带入多项式,求出多项式的点值表示。但是到此为止,我们的复杂度还是 O(n2) O ( n 2 ) 的,甚至由于涉及到复数运算,常数只会比原来更大。于是我们要好好利用单位根的性质进行简化。接下来假设 n n 是2的整数次幂。
考虑关于单位根的n1次多项式 A(ωkn) A ( ω n k ) ,先暴力计算(注意这里的 i i 不是虚数啦):

A(ωnk)=i=0n1aiωnki

FFT接下来做的事情是把这个东西按照奇偶项分类:

A(ωkn)=i=0n21a2iω2kin+i=0n21a2i+1ωk(2i+1)n A ( ω n k ) = ∑ i = 0 n 2 − 1 a 2 i ω n 2 k i + ∑ i = 0 n 2 − 1 a 2 i + 1 ω n k ( 2 i + 1 )

利用单位根性质化简,可以得到:
A(ωkn)=i=0n21a2iωkin2+ωkni=0n21a2i+1ωkin2 A ( ω n k ) = ∑ i = 0 n 2 − 1 a 2 i ω n 2 k i + ω n k ∑ i = 0 n 2 − 1 a 2 i + 1 ω n 2 k i

于是,我们惊奇的发现,按照奇偶项分类之后,我们把有 n n 个要带入的值划分成了2个需要带入n2个值的子问题!再加上引理1,我们可以总结出分治合并过程:
A(ωkn)=i=0n21a2iωkin2+ωkni=0n21a2i+1ωkin20k<n2kZ A ( ω n k ) = ∑ i = 0 n 2 − 1 a 2 i ω n 2 k i + ω n k ∑ i = 0 n 2 − 1 a 2 i + 1 ω n 2 k i , 0 ≤ k < n 2 , k ∈ Z

A(ωk+n2n)=i=0n21a2iωkin2ωkni=0n21a2i+1ωkin20k<n2kZ A ( ω n k + n 2 ) = ∑ i = 0 n 2 − 1 a 2 i ω n 2 k i − ω n k ∑ i = 0 n 2 − 1 a 2 i + 1 ω n 2 k i , 0 ≤ k < n 2 , k ∈ Z

于是,快速傅里叶变换可以在 T(n)=2T(n2)+O(n)=O(nlogn) T ( n ) = 2 T ( n 2 ) + O ( n ) = O ( n l o g n ) 的时间内完成!

傅里叶逆变换

我们已经可以在 O(nlogn) O ( n l o g n ) 的时间内把多项式的系数表示转化为点值表示,但是如何把点值表示转化为系数表示呢?
也就是我们需要解出一个 n n 元一次方程组,考虑把它化为矩阵形式:

[ωn0ωn0ωn0ωn0ωn1ωnn1ωn0ωnn1ωn(n1)2][a0a1an1]=[A(ωn0)A(ωn1)A(ωnn1)]

其中 {a0,a1,,an1} { a 0 , a 1 , ⋯ , a n − 1 } 是多项式 A A 的系数向量。如果我们能够求出来左边矩阵的逆矩阵,问题就会好办很多。考虑如下的矩阵:
[ωn0ωn0ωn0ωn0ωn1ωn(n1)ωn0ωn(n1)ωn(n1)2]

也就是原矩阵中每个数取倒数,考虑这两个矩阵相乘,设原矩阵为 P P ,上面的矩阵为Q

(PQ)i,j=k=0n1Pi,kQk,j=k=0n1ωiknωkjn=k=0n1ωk(ij)n ( P Q ) i , j = ∑ k = 0 n − 1 P i , k · Q k , j = ∑ k = 0 n − 1 ω n i k · ω n − k j = ∑ k = 0 n − 1 ω n k ( i − j )

于是可以分类讨论了。

  1. i=j i = j ,则 (PQ)i,j=n ( P Q ) i , j = n
  2. ij i ≠ j ,则原式可以看做是等比数列的形式,原式 =ωijnωnn1ωn1=0 = ω n i − j · ω n n − 1 ω n − 1 = 0

于是我们得到了一个结论:两个矩阵的乘积除主对角线为 n n ,其它位置全部为0.这可以看做是n倍的单位矩阵,也就是说,我们把 Q Q 矩阵和右边的点值向量相乘,就可以得到系数向量。但是这样的复杂度仍然是O(n2)的。
注意到 ωkn ω n − k 实际上仍然是 n n 次单位根!证明:

(ωnk)n=(e2kπin)n=(e2πi)k=(cos(2π)+i·sin(2π))k=1

因此它仍然满足单位根的所有性质!于是我们可以把FFT时的 ωkn ω n k ωkn ω n − k 代替,再跑一边FFT,得出来的 n n 个数字除以n就是原多项式的系数向量!
于是,傅里叶逆变换也可以在 O(nlogn) O ( n l o g n ) 的时间内完成。
到此时,FFT的递归实现应该也比较好理解了。
注意上面假设的都是n为2的整数次幂!如果不足需要补齐!

FFT的迭代实现

迭代过程
(这里容许我盗一下图……)

我们来观察一下最后一步时所有数字的顺序。考虑把所有二进制串反过来,比如1000变为0001,我们会发现最后一步时fft的顺序就是从0到 n1 n − 1 !(最后两个似乎画反了……)也就是说,原串中第 i i 个数到fft的最后一步时就变成了第rev(i)个数,其中 rev r e v 函数表示翻转一个数的二进制表示。只要我们按照这个排好序,一步一步合并上去就行了!
于是我们从头到尾扫一遍数组,假设当前扫到第 i i 个数,只要rev(i)>i,我们就可以交换 rev(i)i r e v ( i ) 和 i 的值,这样最后得到的数组就是fft最后一步的数组!然后就可以很方便地迭代实现fft了!

const int maxn = 1 << 18;
const long double PI = (long double)3.14159265358979323846;
struct Complex{
    long double r, i;
    Complex(){r = i = 0;}
    Complex(long double a, long double b){r = a, i = b;}
    Complex operator+(const Complex &c) const 
        {return Complex(r + c.r, i + c.i);}
    Complex operator-(const Complex &c) const 
        {return Complex(r - c.r, i - c.i);}
    Complex operator*(const Complex &c) const 
        {return Complex(r * c.r - i * c.i, i * c.r + r * c.i);}
} A[maxn];
void rader(Complex *a, int n){//倒位序
    for(int i = 1, j = n >> 1; i < n - 1; i++){
        if(i < j) swap(a[i], a[j]);//j=rev(i)
        int k = n >> 1;
        for(; j >= k; k >>= 1) j -= k;
        if(j < k) j += k;//反向二进制加法
    }
}
void fft(Complex *a, int n, int rev){
    rader(a, n);
    for(int h = 2; h <= n; h <<= 1){
        int hh = h >> 1;
        Complex wn = Complex(cosl(PI / hh), rev * sinl(PI / hh));
        for(int i = 0; i < n; i += h){
            Complex *ta = a + i, *tb = a + i + hh, w = Complex(1, 0);
            for(int j = 0; j < hh; ++j, ++ta, ++tb){
                Complex x = *ta, y = w * *tb;
                *ta = x + y, *tb = x - y, w = w * wn;
            }
        }
    }
    if(rev == -1) for(int i = 0; i < n; i++)
        a[i] = a[i] * Complex(1.0 / n, 0);
}

一个小优化:正常来说我们都是进行两边DFT,然后点值乘法,再IDFT,但实际上在用FFT算卷积的时候可以去掉一个DFT。比如计算a和b的卷积,我们把需要进行FFT的复数数组的实数部分设置为a,虚数部分设置为b,然后DFT一次,计算自己的平方卷积,再IDFT出来,虚数部分结果除以2就是原来的答案。

快速数论变换(NTT)

FFT已经可以在 O(nlogn) O ( n l o g n ) 的时间内完成多项式点值和系数表示之间的转换,但是在OI中,我们经常要求的是对于某个数求模的结果,这样FFT的精度显然不够了。
考虑在模运算下定义单位根。设模数为质数 p p ,那么它的原根gp1n实际上和 ωn ω n 等价。为什么呢?考虑单位根的几个性质:
1. n n 个单位根互不相等,根据原根定义,原根的0次幂到p1次幂都不相等,上面那 n n 个值自然不相等。
2.单位根的n次幂等于1,这个根据费马小定理,任意与 p p 互质正整数的p1次幂都为1,因此对于上面的也成立。
3.单位根的对称性。证明:

(gp12)2gp11(mod p)gi≢gj(mod p)(ij,0i,j<p)gp121(mod p) ∵ ( g p − 1 2 ) 2 ≡ g p − 1 ≡ 1 ( m o d   p ) 且 g i ≢ g j ( m o d   p ) ( i ≠ j , 0 ≤ i , j < p ) ∴ g p − 1 2 ≡ − 1 ( m o d   p )

证毕。
4.单位根的引理2,这个是显然成立的。
于是我们也可以在 O(nlogn) O ( n l o g n ) 的时间内计算出模运算意义下的卷积了(但似乎常数不小啊……)
注意, p1 p − 1 必须是 n n 的倍数!如9982443531=223×7×17!!
先附上找质数原根的代码(998244353和1004535809( =221×479 = 2 21 × 479 )的原根都是3)

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

ll modmul(ll a, ll b, ll mod){
    ll res = 0;
    for(; b; b >>= 1){
        if(b & 1) res = (res + a) % mod;
        a = (a + a) % mod;
    }
    return res;
}
ll modpow(ll a, ll b, ll mod){
    ll res = 1;
    for(; b; b >>= 1){
        if(b & 1) res = modmul(res, a, mod) % mod;
        a = modmul(a, a, mod) % mod;
    }
    return res;
}
vector<ll> vec;
int main(){
    ll mod;
    while(~scanf("%lld", &mod)){
        vec.clear();
        ll p = mod - 1;
        for(ll i = 2; i * i <= p; i++) if(p % i == 0) {
            vec.push_back(i);
            while(p % i == 0) p /= i;
        }
        if(p > 1) vec.push_back(p);
        int sz = vec.size();
        for(int i = 2;; i++){
            int flag = 1;
            for(int j = 0; j < sz; j++)
                if(modpow(i, (mod - 1) / vec[j], mod) == 1){flag = 0; break;}
            if(flag == 1){printf("%d\n", i); break;}
        }
    }
    return 0;
}

再附上NTT的板子~

typedef long long ll;
const int mod = 998244353, G = 3;
ll modpow(ll a, int b){
    ll res = 1;
    for(; b; b >>= 1){
        if(b & 1) res = res * a % mod;
        a = a * a % mod;
    }
    return res;
}
void rader(ll *a, int n){
    for(int i = 1, j = n >> 1; i < n - 1; i++){
        if(i < j) swap(a[i], a[j]);
        int k = n >> 1;
        for(; j >= k; k >>= 1) j -= k;
        if(j < k) j += k;
    }
}
void NTT(ll *a, int n, int rev){
    rader(a, n);
    for(int h = 2; h <= n; h <<= 1){
        int hh = h >> 1, wn = modpow(G, rev ? mod - 1 - (mod - 1) / h : (mod - 1) / h);
        for(int i = 0; i < n; i += h){
            ll w = 1;
            for(int j = i; j < i + hh; j++){
                int x = a[j], y = w * a[j + hh] % mod;
                a[j] = (x + y) % mod;
                a[j + hh] = (x - y + mod) % mod;
                w = w * wn % mod;
            }
        }
    }
    if(rev){
        int inv = modpow(n, mod - 2);
        for(int i = 0; i < n; i++) a[i] = a[i] * inv % mod;
    }
}

此时我们就可以愉快的做题啦!

例题

BZOJ4555: [Tjoi2016&Heoi2016]求和

原题链接
题意:求如下函数的值:

f(n)=i=0nj=0iS(i,j)2jj! f ( n ) = ∑ i = 0 n ∑ j = 0 i S ( i , j ) · 2 j · j !

其中, S(i,j) S ( i , j ) 指第二类斯特林数。
首先发现,对于 j>iS(i,j)=0 j > i , S ( i , j ) = 0 ,可以考虑扩大 j j 的取值范围,使得它与i无关。
再考虑暴力展开第二类斯特林数(如果不造第二类斯特林数展开的,戳 这里,翻到靠近下面的部分,有公式+证明),则
f(n)=i=0nj=0n2jk=0j(1)jk(jk)ki=j=0n2jj!k=0j(1)jk(jk)!ni=0kik! f ( n ) = ∑ i = 0 n ∑ j = 0 n 2 j ∑ k = 0 j ( − 1 ) j − k ( j k ) k i = ∑ j = 0 n 2 j j ! ∑ k = 0 j ( − 1 ) j − k ( j − k ) ! · ∑ i = 0 n k i k !

右边的 ni=0ki ∑ i = 0 n k i 实际上可以看做是等比数列求和,对于 k=0,1 k = 0 , 1 的情况特判一下,这样右边就成了一个卷积的形式,直接NTT就行了,复杂度 O(nlogn) O ( n l o g n )

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int maxn = 100005, mod = 998244353, G = 3;
int modpow(ll a, int b){
    ll res = 1;
    for(; b; b >>= 1){
        if(b & 1) res = res * a % mod;
        a = a * a % mod;
    }
    return res;
}
void rader(ll *a, int n){
    for(int i = 1, j = n >> 1; i < n - 1; i++){
        if(i < j) swap(a[i], a[j]);
        int k = n >> 1;
        for(; j >= k; k >>= 1) j -= k;
        if(j < k) j += k;
    }
}
void NTT(ll *a, int n, int rev){
    rader(a, n);
    for(int h = 2; h <= n; h <<= 1){
        int hh = h >> 1, wn = modpow(G, rev ? mod - 1 - (mod - 1) / h : (mod - 1) / h);
        for(int i = 0; i < n; i += h){
            ll w = 1;
            for(int j = i; j < i + hh; j++){
                int x = a[j], y = w * a[j + hh] % mod;
                a[j] = (x + y) % mod;
                a[j + hh] = (x - y + mod) % mod;
                w = w * wn % mod;
            }
        }
    }
    if(rev){
        int inv = modpow(n, mod - 2);
        for(int i = 0; i < n; i++) a[i] = a[i] * inv % mod;
    }
}
ll fact[maxn], revf[maxn], rev[maxn], A[1 << 18], B[1 << 18];
int main(){
    int n; scanf("%d", &n);
    rev[1] = fact[0] = revf[0] = 1;
    for(int i = 1; i <= n; i++){
        if(i > 1) rev[i] = mod - (ll)mod / i * rev[mod % i] % mod;
        revf[i] = revf[i - 1] * rev[i] % mod;
        fact[i] = fact[i - 1] * i % mod;
    }
    int tn = 1;
    while(tn < 2 * n + 1) tn <<= 1;
    for(int i = 0; i <= n; i++){
        A[i] = i & 1 ? mod - revf[i] : revf[i];
        if(i > 0) B[i] = (i > 1 ? (modpow(i, n + 1) - 1) * rev[i - 1] % mod : n + 1) * revf[i] % mod;
        else B[i] = 1;
    }
    NTT(A, tn, 0), NTT(B, tn, 0);
    for(int i = 0; i < tn; i++) A[i] = A[i] * B[i] % mod;
    NTT(A, tn, 1);
    ll res = 0;
    for(int i = 1, j = 0; j <= n; j++){
        res = (fact[j] * i % mod * A[j] + res) % mod;
        i = i * 2 % mod;
    }
    printf("%lld\n", res);
    return 0;
}
BZOJ4836: [Lydsy1704月赛]二元运算

原题链接
像这种区别于i,j的卷积可以使用CDQ分治+NTT处理。先考虑把一整段剖成左右两块,分别递归处理,然后再使用NTT计算左边的x对右边的y的贡献。但是这道题右边x对左边y的贡献是减法卷积,我们可以把左边的多项式翻转,再求卷积,理论复杂度为 O(nlog2n) O ( n l o g 2 n ) ,但似乎常数超级大,而且明明在本机跑得比别人快2倍,在BZOJ上却T掉……
不管了,假装自己过了
还是放上我自己常数巨大的代码吧……

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int maxn = 1 << 17;
const double PI = 3.1415926535898;
int A[maxn], B[maxn], rev[maxn], n, m, Q, T;
ll C[maxn];
struct Complex{
    double r, i;
    Complex(){r = i = 0.0;}
    Complex(double a, double b){r = a, i = b;}
    Complex operator+(const Complex &c) const {
        return Complex(r + c.r, i + c.i);
    }
    Complex operator-(const Complex &c) const {
        return Complex(r - c.r, i - c.i);
    }
    Complex operator*(const Complex &c) const {
        return Complex(r * c.r - i * c.i, r * c.i + i * c.r);
    }

} AA[maxn], BB[maxn], R[maxn];
void FFT(Complex *a, int n, int r){
    for(int i = 0; i < n; i++) if(i < rev[i]) swap(a[i], a[rev[i]]);
    for(int h = 2; h <= n; h <<= 1){
        int hh = h >> 1;
        Complex wn = Complex(cos(PI / hh), sin(PI / hh));
        if(r) wn.i = -wn.i;
        Complex *ta = a, *tb = a + hh;
        for(int i = 0; i < n; i += h){
            Complex w = Complex(1, 0);
            for(int j = 0; j < hh; ++j, ++ta, ++tb){
                Complex x = *ta, y = w * *tb;
                *ta = x + y, *tb = x - y;
                w = w * wn;
            }
            ta += hh, tb += hh;
        }
    }
    if(r) for(int i = 0; i < n; i++) a[i].r = a[i].r / n;
}
void mul(Complex *a, Complex *b, int n){
    if(n <= 32){
        for(int i = 0; i < n; i++) R[i] = Complex(0, 0);
        for(int i = 0; i < n >> 1; i++)
        for(int j = 0; j < n >> 1; j++)
            R[i + j] = R[i + j] + a[i] * b[j];
        for(int i = 0; i < n; i++) a[i] = R[i];
    } else {
        FFT(a, n, 0), FFT(b, n, 0);
        for(int i = 0; i < n; i++) a[i] = a[i] * b[i];
        FFT(a, n, 1);
    }
}
void cdq(int l, int r){
    if(l == r - 1){C[0] += A[l] * B[l]; return;}
    int mid = (l + r) >> 1, hlen = r - l, len = hlen << 1;
    int t = __builtin_ctz(hlen);
    for(int i = 0; i < len; i++) rev[i] = rev[i >> 1] >> 1 | (i & 1) << t;
    for(int i = 0; i < len; i++) AA[i] = BB[i] = Complex(0, 0);
    for(int i = l; i < mid; i++) AA[i - l].r = A[i];
    for(int i = mid; i < r; i++) BB[i - mid].r = B[i];
    mul(AA, BB, len);
    for(int i = 0; i < len; i++) C[i + l + mid] += (int)(AA[i].r + 0.1);
    for(int i = 0; i < len; i++) AA[i] = BB[i] = Complex(0, 0);
    for(int i = mid; i < r; i++) AA[i - mid].r = A[i];
    for(int i = l; i < mid; i++) BB[i - l].r = B[mid - i + l - 1];
    mul(AA, BB, len);
    for(int i = 0; i < len; i++) C[i + 1] += (int)(AA[i].r + 0.1);
    cdq(l, mid), cdq(mid, r);
}
const int maxr = 10000000;
char str[maxr], prt[maxr]; int rpos, ppos;
char readc(){
    if(!rpos) fread(str, 1, maxr, stdin);
    char c = str[rpos++];
    if(rpos == maxr) rpos = 0;
    return c;
}
int read(){
    int x; char c;
    while((c = readc()) < '0' || c > '9');
    x = c - '0';
    while((c = readc()) >= '0' && c <= '9') x = x * 10 + c - '0';
    return x;
}
void print(ll x){
    if(x){
        static char sta[20];
        int tp = 0;
        for(; x; x /= 10) sta[tp++] = x % 10 + '0';
        while(tp > 0) prt[ppos++] = sta[--tp];
    } else prt[ppos++] = '0';
    prt[ppos++] = '\n';
}
int main(){
    for(T = read(); T--;){
        n = read(), m = read(), Q = read();
        int mx = 0, N = 1;
        memset(A, 0, sizeof(A));
        memset(B, 0, sizeof(B));
        memset(C, 0, sizeof(C));
        for(int i = 0; i < n; i++){
            int t = read();
            mx = max(mx, t);
            ++A[t];
        }
        for(int i = 0; i < m; i++){
            int t = read();
            mx = max(mx, t);
            ++B[t];
        }
        while(N <= mx) N <<= 1;
        cdq(0, N);
        while(Q--) print(C[read()]);
    }
    fwrite(prt, 1, ppos, stdout);
    return 0;
}
例题BZOJ4827: [Hnoi2017]礼物

原题链接
这道题似乎比较水啊,考虑如何算出确定两个装饰的亮度时不同旋转位置的差异值。我们把 ni=1(xiyi)2 ∑ i = 1 n ( x i − y i ) 2 变成 ni=1xi+ni=1yi2ni=1xiyi ∑ i = 1 n x i + ∑ i = 1 n y i − 2 ∑ i = 1 n x i y i ,会发现前面两个是常数,后面一个是乘法。考虑把乘法化成卷积的形式,把一个手环上所有装饰的信息复制一遍接在后面(对没错就跟普通处理环的方法一样),然后翻转另一个手环的亮度信息,这样做一个卷积,就可以得到旋转不同的角度得到的差异值了。
再考虑如何计算最小差异值。如果我们确定一个亮度时算出了不同位置的差异值,并且找出了最小值所在的位置,那么无论我把一个手环的亮度整体加上多少,卷积后取最小值的位置必然不会改变,因为整体加 k k 可以看做是所有角度上的值都加上了ki=1nyi,然后前面两个常数的变化其实可以暴力算,复杂度 O(nlogn+nm) O ( n l o g n + n m ) 就解决了。
其实再深入一点,差异值关于亮度整体的变化值是一个凹函数,因此可以三分优化到 O(n(logn+logm)) O ( n ( l o g n + l o g m ) )

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int maxn = 1 << 18;
const double PI = 3.14159265358979323846;
int A[maxn], B[maxn], n, m;
struct Complex{
    double r, i;
    Complex(){r = i = 0.0;}
    Complex(double a, double b){r = a, i = b;}
    Complex operator+(const Complex &c) const {
        return Complex(r + c.r, i + c.i);
    }
    Complex operator-(const Complex &c) const {
        return Complex(r - c.r, i - c.i);
    }
    Complex operator*(const Complex &c) const {
        return Complex(r * c.r - i * c.i, i * c.r + r * c.i);
    }
} AA[maxn];
void rader(Complex *a, int n){
    for(int i = 1, j = n >> 1; i < n - 1; i++){
        if(i < j) swap(a[i], a[j]);
        int k = n >> 1;
        for(; j >= k; k >>= 1) j -= k;
        if(j < k) j += k;
    }
}
void fft(Complex *a, int n, int rev){
    rader(a, n);
    for(int h = 2; h <= n; h <<= 1){
        int hh = h >> 1;
        Complex wn = Complex(cos(PI / hh), rev * sin(PI / hh));
        for(int i = 0; i < n; i += h){
            Complex w = Complex(1, 0);
            for(int j = i; j < i + hh; j++){
                Complex x = a[j], y = w * a[j + hh];
                a[j] = x + y, a[j + hh] = x - y;
                w = w * wn;
            }
        }
    }
    if(rev == -1) for(int i = 0; i < n; i++)
        a[i] = a[i] * Complex(1.0 / n, 0);
}
int calc(int i, int mn, int ss){
    int sum = 0;
    for(int j = 0; j < n; j++) sum += (A[j] + i) * (A[j] + i);
    return sum - 2 * (mn + i * ss);
}
int main(){
    scanf("%d%d", &n, &m);
    for(int i = 0; i < n; i++){
        scanf("%d", A + i);
        A[i] = A[i + n] = A[i];
        AA[i].r = AA[i + n].r = A[i];
    }
    for(int i = 0; i < n; i++){
        scanf("%d", B + i);
        AA[n - i - 1].i = B[i];
    }
    int len = 1;
    while(len < 3 * n) len <<= 1;
    fft(AA, len, 1);
    for(int i = 0; i < len; i++) AA[i] = AA[i] * AA[i];
    fft(AA, len, -1);
    int mn = INT_MIN, res = INT_MAX, ini = 0, ss = 0;
    for(int i = n - 1; i < 2 * n - 1; i++){
        mn = max(mn, int(AA[i].i / 2 + 0.5));
        //printf("%d %d\n", i, int(AA[i].i / 2 - 0.5));
    }
    for(int i = 0; i < n; i++) ini += B[i] * B[i], ss += B[i];
    int l = -m, r = m;
    while(l + 3 <= r){
        int len = (r - l + 1) / 3;
        int m1 = l + len, m2 = r - len;
        if(calc(m1, mn, ss) < calc(m2, mn, ss)) r = m2;
        else l = m1;
    }
    for(int i = l; i <= r; i++)
        res = min(res, ini + calc(i, mn, ss));
    printf("%d\n", res);
    return 0;
}

上述代码便使用了以前说的FFT计算时的小优化,两次FFT即可计算出卷积。

例题 洛谷P4491 [HAOI2018]染色

原题链接
其实这道题并不难,主要是推公式的时候一定要细心细心再细心!!
首先我们考虑枚举有多少种颜色恰好出现了s次,再令 g(i,j) g ( i , j ) 表示 i i 种颜色填入j个格子且没有颜色出现s次的方案数。然后对于原题,我们会发现最多只会有 ns ⌊ n s ⌋ 种颜色出现恰好s次,于是令 N=min(ns,m) N = m i n ( ⌊ n s ⌋ , m ) ,可以得到:

ans=i=0Nwi(mi)Aisn(s!)ig(mi,nis) a n s = ∑ i = 0 N w i ( m i ) A n i s ( s ! ) i g ( m − i , n − i s )

再考虑如何计算g。我们可以枚举有恰好多少种颜色出现了s次,然后容斥一下:
g(i,j)=k=0min(js,i)(1)k(ik)Aksj(s!)k(ik)jks g ( i , j ) = ∑ k = 0 m i n ( ⌊ j s ⌋ , i ) ( − 1 ) k ( i k ) A j k s ( s ! ) k ( i − k ) j − k s

带入到原式中,并且把组合数展开:
ans=i=0Nwi(mi)Aisn(s!)ij=0Ni(1)j(mij)Ajsnis(s!)j(mij)nisjs a n s = ∑ i = 0 N w i ( m i ) A n i s ( s ! ) i ∑ j = 0 N − i ( − 1 ) j ( m − i j ) A n − i s j s ( s ! ) j ( m − i − j ) n − i s − j s

=i=0Nwim!n!(mi)!i!(nis)!(s!)ij=0Ni(1)j(mi)!(nis)!(mij)nisjsj!(mij)!(nisjs)!(s!)j = ∑ i = 0 N w i m ! n ! ( m − i ) ! i ! ( n − i s ) ! ( s ! ) i ∑ j = 0 N − i ( − 1 ) j ( m − i ) ! ( n − i s ) ! ( m − i − j ) n − i s − j s j ! ( m − i − j ) ! ( n − i s − j s ) ! ( s ! ) j

我们把类似的项合并,并消去一些相同的项,可以得到:
ans=m!n!i=0Nwii!j=0Ni(1)j(mij)nisjsj!(mij)!(nisjs)!(s!)i+j a n s = m ! n ! ∑ i = 0 N w i i ! ∑ j = 0 N − i ( − 1 ) j ( m − i − j ) n − i s − j s j ! ( m − i − j ) ! ( n − i s − j s ) ! ( s ! ) i + j

发现出现了较多的 i+j i + j ,并且根据它们的取值范围, i+jN i + j ≤ N ,这就很像卷积了!来试试看枚举 i+j i + j ……
ans=m!n!i=0N(mi)nis(mi)!(nis)!(s!)ij=0iwjj!(1)ij(ij)! a n s = m ! n ! ∑ i = 0 N ( m − i ) n − i s ( m − i ) ! ( n − i s ) ! ( s ! ) i ∑ j = 0 i w j j ! · ( − 1 ) i − j ( i − j ) !

好了,到此为止,右边已经是一个卷积的形式了,直接NTT即可,复杂度 O(nlogn) O ( n l o g n )

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int maxn = 100005, maxt = 1 << 18, mod = 1004535809, G = 3;
const int maxmn = 10000005;
ll fact[maxmn], rev[maxmn], A[maxt], B[maxt], val[maxn], n, m, N, S;
ll modpow(ll a, int b){
    ll res = 1;
    for(; b; b >>= 1){
        if(b & 1) res = res * a % mod;
        a = a * a % mod;
    }
    return res;
}
void rader(ll *a, int n){
    for(int i = 1, j = n >> 1; i < n - 1; i++){
        if(i < j) swap(a[i], a[j]);
        int k = n >> 1;
        for(; j >= k; k >>= 1) j -= k;
        if(j < k) j += k;
    }
}
void NTT(ll *a, int n, int rev){
    rader(a, n);
    for(int h = 2; h <= n; h <<= 1){
        int hh = h >> 1, wn = modpow(G, rev ? mod - 1 - (mod - 1) / h : (mod - 1) / h);
        for(int i = 0; i < n; i += h){
            ll w = 1;
            for(int j = i; j < i + hh; j++){
                int x = a[j], y = w * a[j + hh] % mod;
                a[j] = (x + y) % mod;
                a[j + hh] = (x - y + mod) % mod;
                w = w * wn % mod;
            }
        }
    }
    if(rev){
        int inv = modpow(n, mod - 2);
        for(int i = 0; i < n; i++) a[i] = a[i] * inv % mod;
    }
}
int main(){
    scanf("%d%d%d", &n, &m, &S);
    for(int i = 0; i <= m; i++) scanf("%lld", val + i);
    N = min(n / S, m);
    fact[0] = rev[0] = 1;
    int mn = max(n, m);
    for(int i = 1; i <= mn; i++)
        fact[i] = fact[i - 1] * i % mod;
    rev[mn] = modpow(fact[mn], mod - 2);
    for(int i = mn - 1; i > 0; i--)
        rev[i] = rev[i + 1] * (i + 1) % mod;
    for(int i = 0; i <= N; i++){
        A[i] = val[i] * rev[i] % mod;
        B[i] = i & 1 ? mod - rev[i] : rev[i];
    }
    int tn = 1;
    while(tn < 2 * (N + 1)) tn <<= 1;
    NTT(A, tn, 0), NTT(B, tn, 0);
    for(int i = 0; i < tn; i++) A[i] = A[i] * B[i] % mod;
    NTT(A, tn, 1);
    ll sf = 1, res = 0;
    for(int i = 1; i <= S; i++) sf = sf * i % mod;
    for(int i = 0; i <= N; i++){
        ll t = fact[m] * fact[n] % mod * modpow(m - i, n - i * S) % mod;
        t = t * rev[m - i] % mod * rev[n - i * S] % mod * modpow(sf, mod - 1 - i) % mod;
        res = (res + t * A[i]) % mod;
    }
    printf("%lld\n", res);
    return 0;
}
例题BZOJ3992[SDOI2015]序列统计

原题链接
观察到n很大,如果用矩阵乘法的话m又过大了,考虑用倍增(快速幂)解决。
考虑dp。设 f[i][j] f [ i ] [ j ] 表示当前已经选了i个数,得到的乘积模m为j的选取方案数。如果我们能够快速合并 f[a],f[b] f [ a ] , f [ b ] 得到 f[a+b] f [ a + b ] 的值,这道题就解决了。由于

f[a+b][i]=jki(mod m)f[a][j]×f[b][k] f [ a + b ] [ i ] = ∑ j k ≡ i ( m o d   m ) f [ a ] [ j ] × f [ b ] [ k ]

观察到这个东西和卷积长得很像,但是卷积是加,它是乘法。于是我们可以利用题目条件把乘法转化为加法。
令m的原根为g,则0到m-1都可以表示成g的幂次。定义离散对数 lg(i)使glg(i)i(mod m) l g ( i ) 使 g l g ( i ) ≡ i ( m o d   m ) ,并且把 f[i][j] f [ i ] [ j ] 的定义改为表示当前已经选了i个数,得到的乘积模m为 gj g j ,那么上述状态转移方程就可以变为:
f[a+b][i]=j+ki(mod m)f[a][j]×f[b][k] f [ a + b ] [ i ] = ∑ j + k ≡ i ( m o d   m ) f [ a ] [ j ] × f [ b ] [ k ]

答案就是 f[n][lg(x)] f [ n ] [ l g ( x ) ] 。中间算卷积时的模运算可以等卷完了在处理,总复杂度 O(mlognlogm) O ( m l o g n l o g m )

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int maxn = 1 << 14, mod = 1004535809, G = 3;
ll modpow(ll a, ll b, ll p = mod){
    ll res = 1;
    for(; b; b >>= 1){
        if(b & 1) res = res * a % p;
        a = a * a % p;
    }
    return res;
}
void rader(ll *a, int n){
    for(int i = 1, j = n >> 1; i < n - 1; i++){
        if(i < j) swap(a[i], a[j]);
        int k = n >> 1;
        for(; j >= k; k >>= 1) j -= k;
        if(j < k) j += k;
    }
}
void ntt(ll *a, int n, int rev){
    rader(a, n);
    for(int h = 2; h <= n; h <<= 1){
        int hh = h >> 1, wn = modpow(G, (mod - 1 + rev * (mod - 1) / h) % (mod - 1));
        for(int i = 0; i < n; i += h){
            ll w = 1;
            for(int j = i; j < i + hh; j++){
                ll x = a[j], y = w * a[j + hh] % mod;
                a[j] = (x + y) % mod;
                a[j + hh] = (x - y + mod) % mod;
                w = w * wn % mod;
            }
        }
    }
    if(rev == -1){
        int inv = modpow(n, mod - 2);
        for(int i = 0; i < n; i++) a[i] = a[i] * inv % mod;
    }
}
int n, m, x, S, g, id[maxn], num[maxn];
ll A[maxn], B[maxn], C[maxn];
void init(){
    int p = m - 1;
    vector<int> fac;
    for(int i = 2; i * i <= p; i++) if(p % i == 0){
        while(p % i == 0) p /= i;
        fac.push_back(i);
    }
    if(p > 1) fac.push_back(p);
    int s = fac.size();
    for(g = 2;; g++){
        int flag = 1;
        for(int i = 0; i < s; i++)
            if(modpow(g, (m - 1) / fac[i], m) == 1){flag = 0; break;}
        if(flag) break;
    }
    for(int i = 0, pw = 1; i < m - 1; i++, pw = pw * g % m)
        num[id[pw] = i] = pw;
}
int main(){
    scanf("%d%d%d%d", &n, &m, &x, &S);
    init();
    for(int i = 0; i < S; i++){
        int t; scanf("%d", &t);
        if(t % m > 0) ++A[id[t % m]];
    }
    int len = 1;
    while(len < 2 * m - 1) len <<= 1;
    for(int flag = 1; n; n >>= 1){
        if(n & 1){
            if(!flag){
                memcpy(C, A, sizeof(C));
                ntt(C, len, 1), ntt(B, len, 1);
                for(int i = 0; i < len; i++) B[i] = B[i] * C[i] % mod;
                ntt(B, len, -1);
                for(int i = m - 1; i < len; i++)
                    (B[i % (m - 1)] += B[i]) %= mod, B[i] = 0;
            } else memcpy(B, A, sizeof(B)), flag = 0;
        }
        ntt(A, len, 1);
        for(int i = 0; i < len; i++) A[i] = A[i] * A[i] % mod;
        ntt(A, len, -1);
        for(int i = m - 1; i < len; i++)
            (A[i % (m - 1)] += A[i]) %= mod, A[i] = 0;
    }
    printf("%lld\n", B[id[x]]);
    return 0;
}

FFT与字符串匹配

例题 洛谷P4173 残缺的字符串

原题链接
考虑定义字符匹配函数 match(x,y)=(xy)2 m a t c h ( x , y ) = ( x − y ) 2 ,这样只有当两个字符相等时函数值才为0。于是可以考虑类似的定义字符串匹配函数:

M(A,B)=i=0lengthmatch(A(i),B(i)) M ( A , B ) = ∑ i = 0 l e n g t h m a t c h ( A ( i ) , B ( i ) )

于是只有当此函数值为0时,才代表两个字符串相匹配。对于这一题,由于有任意字符的出现,我们定义字符匹配函数 match(x,y)=(xy)2xy m a t c h ( x , y ) = ( x − y ) 2 x y ,也就是说原题中字符”*”的值定义为0,其它从1开始递增。
由于我们希望这个算式出现卷积的形式,可以考虑翻转字符串A,于是从B的a位置开始匹配A的匹配函数就变成了:
M(a)=i=0n1(A(n1i)B(i+a))2A(n1i)B(i+a) M ( a ) = ∑ i = 0 n − 1 ( A ( n − 1 − i ) − B ( i + a ) ) 2 A ( n − 1 − i ) B ( i + a )

=i=0n1A(n1i)3B(i+a)+i=0n1A(n1i)B(i+a)32i=0n1A(n1i)2B(i+a)2 = ∑ i = 0 n − 1 A ( n − 1 − i ) 3 B ( i + a ) + ∑ i = 0 n − 1 A ( n − 1 − i ) B ( i + a ) 3 − 2 ∑ i = 0 n − 1 A ( n − 1 − i ) 2 B ( i + a ) 2

然后就变成了三个卷积,加一下,如果卷积的第k个位置上为0,那也就代表着B的k-n+1位置与A匹配(注意这里下标都从0开始!),复杂度 O(nlogn) O ( n l o g n ) ,常数略大。

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int maxn = 1 << 20, mod = 998244353, G = 3;
int id[128], A[maxn], B[maxn], ans[maxn], n, m, ppos;
char sa[maxn], sb[maxn], prt[10000000];
void print(int x, char c){
    if(x){
        static char sta[10];
        int tp = 0;
        for(; x; x /= 10) sta[tp++] = '0' + x % 10;
        while(tp > 0) prt[ppos++] = sta[--tp];
    } else prt[ppos++] = '0';
    prt[ppos++] = c;
}
inline int modpow(int a, int b){
    int res = 1;
    for(; b; b >>= 1){
        if(b & 1) res = (ll)res * a % mod;
        a = (ll)a * a % mod;
    }
    return res;
}
inline void rader(int *a, int n){
    for(int i = 1, j = n >> 1; i < n - 1; i++){
        if(i < j) swap(a[i], a[j]);
        int k = n >> 1;
        for(; j >= k; k >>= 1) j -= k;
        if(j < k) j += k;
    }
}
void ntt(int *a, int n, int rev){
    rader(a, n);
    for(int h = 2; h <= n; h <<= 1){
        int hh = h >> 1, wn = modpow(G, rev ? mod - 1 - (mod - 1) / h : (mod - 1) / h);
        for(int i = 0; i < n; i += h){
            ll w = 1;
            for(int j = i; j < i + hh; j++){
                int x = a[j], y = w * a[j + hh] % mod;
                a[j] = (x + y) % mod, a[j + hh] = (x - y + mod) % mod;
                w = w * wn % mod;
            }
        }
    }
    if(rev){
        int inv = modpow(n, mod - 2);
        for(int i = 0; i < n; i++) a[i] = (ll)a[i] * inv % mod;
    }
}
int main(){
    for(char i = 'a'; i <= 'z'; i++) id[i] = i - 'a' + 1;
    scanf("%d%d%s%s", &n, &m, sa, sb);
    int len = 1;
    while(len < n + m - 1) len <<= 1;
    for(int i = 0; i < n; i++) A[n - i - 1] = id[sa[i]] * id[sa[i]] * id[sa[i]];
    for(int i = 0; i < m; i++) B[i] = id[sb[i]];
    ntt(A, len, 0), ntt(B, len, 0);
    for(int i = 0; i < len; i++) ans[i] = (ll)A[i] * B[i] % mod;

    memset(A, 0, sizeof(A)), memset(B, 0, sizeof(B));
    for(int i = 0; i < n; i++) A[n - i - 1] = id[sa[i]];
    for(int i = 0; i < m; i++) B[i] = id[sb[i]] * id[sb[i]] * id[sb[i]];
    ntt(A, len, 0), ntt(B, len, 0);
    for(int i = 0; i < len; i++) ans[i] = ((ll)A[i] * B[i] + ans[i]) % mod;

    memset(A, 0, sizeof(A)), memset(B, 0, sizeof(B));
    for(int i = 0; i < n; i++) A[n - i - 1] = id[sa[i]] * id[sa[i]];
    for(int i = 0; i < m; i++) B[i] = id[sb[i]] * id[sb[i]];
    ntt(A, len, 0), ntt(B, len, 0);
    for(int i = 0; i < len; i++) ans[i] = (ans[i] - 2LL * A[i] * B[i] % mod + mod) % mod;

    ntt(ans, len, 1);
    int tp = 0;
    for(int i = n - 1; i < m; i++) if(!ans[i]) A[tp++] = i - n + 2;
    print(tp, '\n');
    for(int i = 0; i < tp; i++) print(A[i], ' ');
    fwrite(prt, 1, ppos, stdout);
    return 0;
}
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值