多项式的系数表示法
考虑多项式 A(x)=∑i=0naixi A ( x ) = ∑ i = 0 n a i x i ,其中 {a0,a1,…,an} { a 0 , a 1 , … , a n } 被称为多项式 A(x) A ( x ) 的系数向量。每个多项式都有唯一的系数向量,每个系数向量都对应唯一的多项式。
多项式的点值表示法
我们可以把多项式 A(x) A ( x ) 看做是一个 n n 次函数,我们可以取个不同的值 b0,b1,⋯,bn b 0 , b 1 , ⋯ , b n 带入分别求出 n+1 n + 1 个多项式的值 c0,c1,⋯,cn c 0 , c 1 , ⋯ , c n 。可以看出,从系数表示法到点值表示法是唯一的,而点值表示法在系数未知的时候可以看做是一个 n+1 n + 1 元一次方程组,可以解出唯一系数。因此点值表示与多项式也一一对应。
多项式乘法
令
C(x)=A(x)B(x)=∑i=0n∑j=0maibjxi+j
C
(
x
)
=
A
(
x
)
B
(
x
)
=
∑
i
=
0
n
∑
j
=
0
m
a
i
b
j
x
i
+
j
,其中
A(x),B(x)
A
(
x
)
,
B
(
x
)
分别是
n,m
n
,
m
次多项式,
A(x),B(x)
A
(
x
)
,
B
(
x
)
的系数向量是
a→,b→
a
→
,
b
→
。
容易发现,用系数表示法使两个向量相乘是
O(n2)
O
(
n
2
)
的复杂度,那如何才能优化呢?
考虑两个点值表示的多项式相乘,易发现此时只需要把两个多项式的对应点值相乘即可,复杂度为
O(n)
O
(
n
)
。但是如何把系数表示转化为点值表示,再转化回来呢?。
如果我们选取
n+1
n
+
1
个值暴力代入,复杂度仍然为
O(n2)
O
(
n
2
)
,甚至转化回来的时候会用到
O(n3)
O
(
n
3
)
的高斯消元,难道点值表示就没有任何可取之处了吗?
因此,一种算法叫做“快速傅里叶变换”诞生了,它可以在
O(nlogn)
O
(
n
l
o
g
n
)
的时间内完成上述两部转化。
快速傅里叶变换(FFT)
单位根
若
xn=1
x
n
=
1
,则
x
x
被称为次单位根。
n
n
次单位根共有个,分别形如
e2kπin,0≤k<n,k∈Z
e
2
k
π
i
n
,
0
≤
k
<
n
,
k
∈
Z
,注意这里的
i
i
是虚数单位。为什么呢?
倒数第二步使用了欧拉公式 exi=cos(x)+i⋅sin(x) e x i = c o s ( x ) + i · s i n ( x ) ,因此我们也可以得到 e2kπin=cos(2kπn)+i⋅sin(2kπn) e 2 k π i n = c o s ( 2 k π n ) + i · s i n ( 2 k π n ) ,这样我们就可以用平常的复数表示去计算单位根了。
为了方便,我们记 ωn=e2πin ω n = e 2 π i n ,则 n n 个单位根分别为。
单位根的性质
在讨论性质时,均假定 n n 为偶数。
引理1
单位根具有对称性,即。这个定理是比较好证明的,因为有
引理2
这个利用性质也是很好证明的。
FFT算法
上面我们说了那么多,究竟是要干什么呢?没错!把单位根当做数值带入多项式,求出多项式的点值表示。但是到此为止,我们的复杂度还是
O(n2)
O
(
n
2
)
的,甚至由于涉及到复数运算,常数只会比原来更大。于是我们要好好利用单位根的性质进行简化。接下来假设
n
n
是2的整数次幂。
考虑关于单位根的次多项式
A(ωkn)
A
(
ω
n
k
)
,先暴力计算(注意这里的
i
i
不是虚数啦):
FFT接下来做的事情是把这个东西按照奇偶项分类:
利用单位根性质化简,可以得到:
于是,我们惊奇的发现,按照奇偶项分类之后,我们把有 n n 个要带入的值划分成了2个需要带入个值的子问题!再加上引理1,我们可以总结出分治合并过程:
于是,快速傅里叶变换可以在 T(n)=2T(n2)+O(n)=O(nlogn) T ( n ) = 2 T ( n 2 ) + O ( n ) = O ( n l o g n ) 的时间内完成!
傅里叶逆变换
我们已经可以在
O(nlogn)
O
(
n
l
o
g
n
)
的时间内把多项式的系数表示转化为点值表示,但是如何把点值表示转化为系数表示呢?
也就是我们需要解出一个
n
n
元一次方程组,考虑把它化为矩阵形式:
其中 {a0,a1,⋯,an−1} { a 0 , a 1 , ⋯ , a n − 1 } 是多项式 A A 的系数向量。如果我们能够求出来左边矩阵的逆矩阵,问题就会好办很多。考虑如下的矩阵:
也就是原矩阵中每个数取倒数,考虑这两个矩阵相乘,设原矩阵为 P P ,上面的矩阵为。
于是可以分类讨论了。
- 若 i=j i = j ,则 (PQ)i,j=n ( P Q ) i , j = n 。
- 若 i≠j i ≠ j ,则原式可以看做是等比数列的形式,原式 =ωi−jn⋅ωnn−1ωn−1=0 = ω n i − j · ω n n − 1 ω n − 1 = 0
于是我们得到了一个结论:两个矩阵的乘积除主对角线为
n
n
,其它位置全部为0.这可以看做是倍的单位矩阵,也就是说,我们把
Q
Q
矩阵和右边的点值向量相乘,就可以得到系数向量。但是这样的复杂度仍然是的。
注意到
ω−kn
ω
n
−
k
实际上仍然是
n
n
次单位根!证明:
因此它仍然满足单位根的所有性质!于是我们可以把FFT时的 ωkn ω n k 用 ω−kn ω n − k 代替,再跑一边FFT,得出来的 n n 个数字除以就是原多项式的系数向量!
于是,傅里叶逆变换也可以在 O(nlogn) O ( n l o g n ) 的时间内完成。
到此时,FFT的递归实现应该也比较好理解了。
注意上面假设的都是n为2的整数次幂!如果不足需要补齐!
FFT的迭代实现
(这里容许我盗一下图……)
我们来观察一下最后一步时所有数字的顺序。考虑把所有二进制串反过来,比如1000变为0001,我们会发现最后一步时fft的顺序就是从0到
n−1
n
−
1
!(最后两个似乎画反了……)也就是说,原串中第
i
i
个数到fft的最后一步时就变成了第个数,其中
rev
r
e
v
函数表示翻转一个数的二进制表示。只要我们按照这个排好序,一步一步合并上去就行了!
于是我们从头到尾扫一遍数组,假设当前扫到第
i
i
个数,只要,我们就可以交换
rev(i)和i
r
e
v
(
i
)
和
i
的值,这样最后得到的数组就是fft最后一步的数组!然后就可以很方便地迭代实现fft了!
const int maxn = 1 << 18;
const long double PI = (long double)3.14159265358979323846;
struct Complex{
long double r, i;
Complex(){r = i = 0;}
Complex(long double a, long double b){r = a, i = b;}
Complex operator+(const Complex &c) const
{return Complex(r + c.r, i + c.i);}
Complex operator-(const Complex &c) const
{return Complex(r - c.r, i - c.i);}
Complex operator*(const Complex &c) const
{return Complex(r * c.r - i * c.i, i * c.r + r * c.i);}
} A[maxn];
void rader(Complex *a, int n){//倒位序
for(int i = 1, j = n >> 1; i < n - 1; i++){
if(i < j) swap(a[i], a[j]);//j=rev(i)
int k = n >> 1;
for(; j >= k; k >>= 1) j -= k;
if(j < k) j += k;//反向二进制加法
}
}
void fft(Complex *a, int n, int rev){
rader(a, n);
for(int h = 2; h <= n; h <<= 1){
int hh = h >> 1;
Complex wn = Complex(cosl(PI / hh), rev * sinl(PI / hh));
for(int i = 0; i < n; i += h){
Complex *ta = a + i, *tb = a + i + hh, w = Complex(1, 0);
for(int j = 0; j < hh; ++j, ++ta, ++tb){
Complex x = *ta, y = w * *tb;
*ta = x + y, *tb = x - y, w = w * wn;
}
}
}
if(rev == -1) for(int i = 0; i < n; i++)
a[i] = a[i] * Complex(1.0 / n, 0);
}
一个小优化:正常来说我们都是进行两边DFT,然后点值乘法,再IDFT,但实际上在用FFT算卷积的时候可以去掉一个DFT。比如计算a和b的卷积,我们把需要进行FFT的复数数组的实数部分设置为a,虚数部分设置为b,然后DFT一次,计算自己的平方卷积,再IDFT出来,虚数部分结果除以2就是原来的答案。
快速数论变换(NTT)
FFT已经可以在
O(nlogn)
O
(
n
l
o
g
n
)
的时间内完成多项式点值和系数表示之间的转换,但是在OI中,我们经常要求的是对于某个数求模的结果,这样FFT的精度显然不够了。
考虑在模运算下定义单位根。设模数为质数
p
p
,那么它的原根实际上和
ωn
ω
n
等价。为什么呢?考虑单位根的几个性质:
1.
n
n
个单位根互不相等,根据原根定义,原根的0次幂到次幂都不相等,上面那
n
n
个值自然不相等。
2.单位根的次幂等于1,这个根据费马小定理,任意与
p
p
互质正整数的次幂都为1,因此对于上面的也成立。
3.单位根的对称性。证明:
证毕。
4.单位根的引理2,这个是显然成立的。
于是我们也可以在 O(nlogn) O ( n l o g n ) 的时间内计算出模运算意义下的卷积了(但似乎常数不小啊……)
注意, p−1 p − 1 必须是 n n 的倍数!如!!
先附上找质数原根的代码(998244353和1004535809( =221×479 = 2 21 × 479 )的原根都是3)
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
ll modmul(ll a, ll b, ll mod){
ll res = 0;
for(; b; b >>= 1){
if(b & 1) res = (res + a) % mod;
a = (a + a) % mod;
}
return res;
}
ll modpow(ll a, ll b, ll mod){
ll res = 1;
for(; b; b >>= 1){
if(b & 1) res = modmul(res, a, mod) % mod;
a = modmul(a, a, mod) % mod;
}
return res;
}
vector<ll> vec;
int main(){
ll mod;
while(~scanf("%lld", &mod)){
vec.clear();
ll p = mod - 1;
for(ll i = 2; i * i <= p; i++) if(p % i == 0) {
vec.push_back(i);
while(p % i == 0) p /= i;
}
if(p > 1) vec.push_back(p);
int sz = vec.size();
for(int i = 2;; i++){
int flag = 1;
for(int j = 0; j < sz; j++)
if(modpow(i, (mod - 1) / vec[j], mod) == 1){flag = 0; break;}
if(flag == 1){printf("%d\n", i); break;}
}
}
return 0;
}
再附上NTT的板子~
typedef long long ll;
const int mod = 998244353, G = 3;
ll modpow(ll a, int b){
ll res = 1;
for(; b; b >>= 1){
if(b & 1) res = res * a % mod;
a = a * a % mod;
}
return res;
}
void rader(ll *a, int n){
for(int i = 1, j = n >> 1; i < n - 1; i++){
if(i < j) swap(a[i], a[j]);
int k = n >> 1;
for(; j >= k; k >>= 1) j -= k;
if(j < k) j += k;
}
}
void NTT(ll *a, int n, int rev){
rader(a, n);
for(int h = 2; h <= n; h <<= 1){
int hh = h >> 1, wn = modpow(G, rev ? mod - 1 - (mod - 1) / h : (mod - 1) / h);
for(int i = 0; i < n; i += h){
ll w = 1;
for(int j = i; j < i + hh; j++){
int x = a[j], y = w * a[j + hh] % mod;
a[j] = (x + y) % mod;
a[j + hh] = (x - y + mod) % mod;
w = w * wn % mod;
}
}
}
if(rev){
int inv = modpow(n, mod - 2);
for(int i = 0; i < n; i++) a[i] = a[i] * inv % mod;
}
}
此时我们就可以愉快的做题啦!
例题
BZOJ4555: [Tjoi2016&Heoi2016]求和
原题链接
题意:求如下函数的值:
其中, S(i,j) S ( i , j ) 指第二类斯特林数。
首先发现,对于 j>i,S(i,j)=0 j > i , S ( i , j ) = 0 ,可以考虑扩大 j j 的取值范围,使得它与无关。
再考虑暴力展开第二类斯特林数(如果不造第二类斯特林数展开的,戳 这里,翻到靠近下面的部分,有公式+证明),则
右边的 ∑ni=0ki ∑ i = 0 n k i 实际上可以看做是等比数列求和,对于 k=0,1 k = 0 , 1 的情况特判一下,这样右边就成了一个卷积的形式,直接NTT就行了,复杂度 O(nlogn) O ( n l o g n ) 。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 100005, mod = 998244353, G = 3;
int modpow(ll a, int b){
ll res = 1;
for(; b; b >>= 1){
if(b & 1) res = res * a % mod;
a = a * a % mod;
}
return res;
}
void rader(ll *a, int n){
for(int i = 1, j = n >> 1; i < n - 1; i++){
if(i < j) swap(a[i], a[j]);
int k = n >> 1;
for(; j >= k; k >>= 1) j -= k;
if(j < k) j += k;
}
}
void NTT(ll *a, int n, int rev){
rader(a, n);
for(int h = 2; h <= n; h <<= 1){
int hh = h >> 1, wn = modpow(G, rev ? mod - 1 - (mod - 1) / h : (mod - 1) / h);
for(int i = 0; i < n; i += h){
ll w = 1;
for(int j = i; j < i + hh; j++){
int x = a[j], y = w * a[j + hh] % mod;
a[j] = (x + y) % mod;
a[j + hh] = (x - y + mod) % mod;
w = w * wn % mod;
}
}
}
if(rev){
int inv = modpow(n, mod - 2);
for(int i = 0; i < n; i++) a[i] = a[i] * inv % mod;
}
}
ll fact[maxn], revf[maxn], rev[maxn], A[1 << 18], B[1 << 18];
int main(){
int n; scanf("%d", &n);
rev[1] = fact[0] = revf[0] = 1;
for(int i = 1; i <= n; i++){
if(i > 1) rev[i] = mod - (ll)mod / i * rev[mod % i] % mod;
revf[i] = revf[i - 1] * rev[i] % mod;
fact[i] = fact[i - 1] * i % mod;
}
int tn = 1;
while(tn < 2 * n + 1) tn <<= 1;
for(int i = 0; i <= n; i++){
A[i] = i & 1 ? mod - revf[i] : revf[i];
if(i > 0) B[i] = (i > 1 ? (modpow(i, n + 1) - 1) * rev[i - 1] % mod : n + 1) * revf[i] % mod;
else B[i] = 1;
}
NTT(A, tn, 0), NTT(B, tn, 0);
for(int i = 0; i < tn; i++) A[i] = A[i] * B[i] % mod;
NTT(A, tn, 1);
ll res = 0;
for(int i = 1, j = 0; j <= n; j++){
res = (fact[j] * i % mod * A[j] + res) % mod;
i = i * 2 % mod;
}
printf("%lld\n", res);
return 0;
}
BZOJ4836: [Lydsy1704月赛]二元运算
原题链接
像这种区别于i,j的卷积可以使用CDQ分治+NTT处理。先考虑把一整段剖成左右两块,分别递归处理,然后再使用NTT计算左边的x对右边的y的贡献。但是这道题右边x对左边y的贡献是减法卷积,我们可以把左边的多项式翻转,再求卷积,理论复杂度为
O(nlog2n)
O
(
n
l
o
g
2
n
)
,但似乎常数超级大,而且明明在本机跑得比别人快2倍,在BZOJ上却T掉……
不管了,假装自己过了
还是放上我自己常数巨大的代码吧……
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 1 << 17;
const double PI = 3.1415926535898;
int A[maxn], B[maxn], rev[maxn], n, m, Q, T;
ll C[maxn];
struct Complex{
double r, i;
Complex(){r = i = 0.0;}
Complex(double a, double b){r = a, i = b;}
Complex operator+(const Complex &c) const {
return Complex(r + c.r, i + c.i);
}
Complex operator-(const Complex &c) const {
return Complex(r - c.r, i - c.i);
}
Complex operator*(const Complex &c) const {
return Complex(r * c.r - i * c.i, r * c.i + i * c.r);
}
} AA[maxn], BB[maxn], R[maxn];
void FFT(Complex *a, int n, int r){
for(int i = 0; i < n; i++) if(i < rev[i]) swap(a[i], a[rev[i]]);
for(int h = 2; h <= n; h <<= 1){
int hh = h >> 1;
Complex wn = Complex(cos(PI / hh), sin(PI / hh));
if(r) wn.i = -wn.i;
Complex *ta = a, *tb = a + hh;
for(int i = 0; i < n; i += h){
Complex w = Complex(1, 0);
for(int j = 0; j < hh; ++j, ++ta, ++tb){
Complex x = *ta, y = w * *tb;
*ta = x + y, *tb = x - y;
w = w * wn;
}
ta += hh, tb += hh;
}
}
if(r) for(int i = 0; i < n; i++) a[i].r = a[i].r / n;
}
void mul(Complex *a, Complex *b, int n){
if(n <= 32){
for(int i = 0; i < n; i++) R[i] = Complex(0, 0);
for(int i = 0; i < n >> 1; i++)
for(int j = 0; j < n >> 1; j++)
R[i + j] = R[i + j] + a[i] * b[j];
for(int i = 0; i < n; i++) a[i] = R[i];
} else {
FFT(a, n, 0), FFT(b, n, 0);
for(int i = 0; i < n; i++) a[i] = a[i] * b[i];
FFT(a, n, 1);
}
}
void cdq(int l, int r){
if(l == r - 1){C[0] += A[l] * B[l]; return;}
int mid = (l + r) >> 1, hlen = r - l, len = hlen << 1;
int t = __builtin_ctz(hlen);
for(int i = 0; i < len; i++) rev[i] = rev[i >> 1] >> 1 | (i & 1) << t;
for(int i = 0; i < len; i++) AA[i] = BB[i] = Complex(0, 0);
for(int i = l; i < mid; i++) AA[i - l].r = A[i];
for(int i = mid; i < r; i++) BB[i - mid].r = B[i];
mul(AA, BB, len);
for(int i = 0; i < len; i++) C[i + l + mid] += (int)(AA[i].r + 0.1);
for(int i = 0; i < len; i++) AA[i] = BB[i] = Complex(0, 0);
for(int i = mid; i < r; i++) AA[i - mid].r = A[i];
for(int i = l; i < mid; i++) BB[i - l].r = B[mid - i + l - 1];
mul(AA, BB, len);
for(int i = 0; i < len; i++) C[i + 1] += (int)(AA[i].r + 0.1);
cdq(l, mid), cdq(mid, r);
}
const int maxr = 10000000;
char str[maxr], prt[maxr]; int rpos, ppos;
char readc(){
if(!rpos) fread(str, 1, maxr, stdin);
char c = str[rpos++];
if(rpos == maxr) rpos = 0;
return c;
}
int read(){
int x; char c;
while((c = readc()) < '0' || c > '9');
x = c - '0';
while((c = readc()) >= '0' && c <= '9') x = x * 10 + c - '0';
return x;
}
void print(ll x){
if(x){
static char sta[20];
int tp = 0;
for(; x; x /= 10) sta[tp++] = x % 10 + '0';
while(tp > 0) prt[ppos++] = sta[--tp];
} else prt[ppos++] = '0';
prt[ppos++] = '\n';
}
int main(){
for(T = read(); T--;){
n = read(), m = read(), Q = read();
int mx = 0, N = 1;
memset(A, 0, sizeof(A));
memset(B, 0, sizeof(B));
memset(C, 0, sizeof(C));
for(int i = 0; i < n; i++){
int t = read();
mx = max(mx, t);
++A[t];
}
for(int i = 0; i < m; i++){
int t = read();
mx = max(mx, t);
++B[t];
}
while(N <= mx) N <<= 1;
cdq(0, N);
while(Q--) print(C[read()]);
}
fwrite(prt, 1, ppos, stdout);
return 0;
}
例题BZOJ4827: [Hnoi2017]礼物
原题链接
这道题似乎比较水啊,考虑如何算出确定两个装饰的亮度时不同旋转位置的差异值。我们把
∑ni=1(xi−yi)2
∑
i
=
1
n
(
x
i
−
y
i
)
2
变成
∑ni=1xi+∑ni=1yi−2∑ni=1xiyi
∑
i
=
1
n
x
i
+
∑
i
=
1
n
y
i
−
2
∑
i
=
1
n
x
i
y
i
,会发现前面两个是常数,后面一个是乘法。考虑把乘法化成卷积的形式,把一个手环上所有装饰的信息复制一遍接在后面(对没错就跟普通处理环的方法一样),然后翻转另一个手环的亮度信息,这样做一个卷积,就可以得到旋转不同的角度得到的差异值了。
再考虑如何计算最小差异值。如果我们确定一个亮度时算出了不同位置的差异值,并且找出了最小值所在的位置,那么无论我把一个手环的亮度整体加上多少,卷积后取最小值的位置必然不会改变,因为整体加
k
k
可以看做是所有角度上的值都加上了,然后前面两个常数的变化其实可以暴力算,复杂度
O(nlogn+nm)
O
(
n
l
o
g
n
+
n
m
)
就解决了。
其实再深入一点,差异值关于亮度整体的变化值是一个凹函数,因此可以三分优化到
O(n(logn+logm))
O
(
n
(
l
o
g
n
+
l
o
g
m
)
)
。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 1 << 18;
const double PI = 3.14159265358979323846;
int A[maxn], B[maxn], n, m;
struct Complex{
double r, i;
Complex(){r = i = 0.0;}
Complex(double a, double b){r = a, i = b;}
Complex operator+(const Complex &c) const {
return Complex(r + c.r, i + c.i);
}
Complex operator-(const Complex &c) const {
return Complex(r - c.r, i - c.i);
}
Complex operator*(const Complex &c) const {
return Complex(r * c.r - i * c.i, i * c.r + r * c.i);
}
} AA[maxn];
void rader(Complex *a, int n){
for(int i = 1, j = n >> 1; i < n - 1; i++){
if(i < j) swap(a[i], a[j]);
int k = n >> 1;
for(; j >= k; k >>= 1) j -= k;
if(j < k) j += k;
}
}
void fft(Complex *a, int n, int rev){
rader(a, n);
for(int h = 2; h <= n; h <<= 1){
int hh = h >> 1;
Complex wn = Complex(cos(PI / hh), rev * sin(PI / hh));
for(int i = 0; i < n; i += h){
Complex w = Complex(1, 0);
for(int j = i; j < i + hh; j++){
Complex x = a[j], y = w * a[j + hh];
a[j] = x + y, a[j + hh] = x - y;
w = w * wn;
}
}
}
if(rev == -1) for(int i = 0; i < n; i++)
a[i] = a[i] * Complex(1.0 / n, 0);
}
int calc(int i, int mn, int ss){
int sum = 0;
for(int j = 0; j < n; j++) sum += (A[j] + i) * (A[j] + i);
return sum - 2 * (mn + i * ss);
}
int main(){
scanf("%d%d", &n, &m);
for(int i = 0; i < n; i++){
scanf("%d", A + i);
A[i] = A[i + n] = A[i];
AA[i].r = AA[i + n].r = A[i];
}
for(int i = 0; i < n; i++){
scanf("%d", B + i);
AA[n - i - 1].i = B[i];
}
int len = 1;
while(len < 3 * n) len <<= 1;
fft(AA, len, 1);
for(int i = 0; i < len; i++) AA[i] = AA[i] * AA[i];
fft(AA, len, -1);
int mn = INT_MIN, res = INT_MAX, ini = 0, ss = 0;
for(int i = n - 1; i < 2 * n - 1; i++){
mn = max(mn, int(AA[i].i / 2 + 0.5));
//printf("%d %d\n", i, int(AA[i].i / 2 - 0.5));
}
for(int i = 0; i < n; i++) ini += B[i] * B[i], ss += B[i];
int l = -m, r = m;
while(l + 3 <= r){
int len = (r - l + 1) / 3;
int m1 = l + len, m2 = r - len;
if(calc(m1, mn, ss) < calc(m2, mn, ss)) r = m2;
else l = m1;
}
for(int i = l; i <= r; i++)
res = min(res, ini + calc(i, mn, ss));
printf("%d\n", res);
return 0;
}
上述代码便使用了以前说的FFT计算时的小优化,两次FFT即可计算出卷积。
例题 洛谷P4491 [HAOI2018]染色
原题链接
其实这道题并不难,主要是推公式的时候一定要细心细心再细心!!
首先我们考虑枚举有多少种颜色恰好出现了s次,再令
g(i,j)
g
(
i
,
j
)
表示
i
i
种颜色填入个格子且没有颜色出现s次的方案数。然后对于原题,我们会发现最多只会有
⌊ns⌋
⌊
n
s
⌋
种颜色出现恰好s次,于是令
N=min(⌊ns⌋,m)
N
=
m
i
n
(
⌊
n
s
⌋
,
m
)
,可以得到:
再考虑如何计算g。我们可以枚举有恰好多少种颜色出现了s次,然后容斥一下:
带入到原式中,并且把组合数展开:
我们把类似的项合并,并消去一些相同的项,可以得到:
发现出现了较多的 i+j i + j ,并且根据它们的取值范围, i+j≤N i + j ≤ N ,这就很像卷积了!来试试看枚举 i+j i + j ……
好了,到此为止,右边已经是一个卷积的形式了,直接NTT即可,复杂度 O(nlogn) O ( n l o g n ) 。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 100005, maxt = 1 << 18, mod = 1004535809, G = 3;
const int maxmn = 10000005;
ll fact[maxmn], rev[maxmn], A[maxt], B[maxt], val[maxn], n, m, N, S;
ll modpow(ll a, int b){
ll res = 1;
for(; b; b >>= 1){
if(b & 1) res = res * a % mod;
a = a * a % mod;
}
return res;
}
void rader(ll *a, int n){
for(int i = 1, j = n >> 1; i < n - 1; i++){
if(i < j) swap(a[i], a[j]);
int k = n >> 1;
for(; j >= k; k >>= 1) j -= k;
if(j < k) j += k;
}
}
void NTT(ll *a, int n, int rev){
rader(a, n);
for(int h = 2; h <= n; h <<= 1){
int hh = h >> 1, wn = modpow(G, rev ? mod - 1 - (mod - 1) / h : (mod - 1) / h);
for(int i = 0; i < n; i += h){
ll w = 1;
for(int j = i; j < i + hh; j++){
int x = a[j], y = w * a[j + hh] % mod;
a[j] = (x + y) % mod;
a[j + hh] = (x - y + mod) % mod;
w = w * wn % mod;
}
}
}
if(rev){
int inv = modpow(n, mod - 2);
for(int i = 0; i < n; i++) a[i] = a[i] * inv % mod;
}
}
int main(){
scanf("%d%d%d", &n, &m, &S);
for(int i = 0; i <= m; i++) scanf("%lld", val + i);
N = min(n / S, m);
fact[0] = rev[0] = 1;
int mn = max(n, m);
for(int i = 1; i <= mn; i++)
fact[i] = fact[i - 1] * i % mod;
rev[mn] = modpow(fact[mn], mod - 2);
for(int i = mn - 1; i > 0; i--)
rev[i] = rev[i + 1] * (i + 1) % mod;
for(int i = 0; i <= N; i++){
A[i] = val[i] * rev[i] % mod;
B[i] = i & 1 ? mod - rev[i] : rev[i];
}
int tn = 1;
while(tn < 2 * (N + 1)) tn <<= 1;
NTT(A, tn, 0), NTT(B, tn, 0);
for(int i = 0; i < tn; i++) A[i] = A[i] * B[i] % mod;
NTT(A, tn, 1);
ll sf = 1, res = 0;
for(int i = 1; i <= S; i++) sf = sf * i % mod;
for(int i = 0; i <= N; i++){
ll t = fact[m] * fact[n] % mod * modpow(m - i, n - i * S) % mod;
t = t * rev[m - i] % mod * rev[n - i * S] % mod * modpow(sf, mod - 1 - i) % mod;
res = (res + t * A[i]) % mod;
}
printf("%lld\n", res);
return 0;
}
例题BZOJ3992[SDOI2015]序列统计
原题链接
观察到n很大,如果用矩阵乘法的话m又过大了,考虑用倍增(快速幂)解决。
考虑dp。设
f[i][j]
f
[
i
]
[
j
]
表示当前已经选了i个数,得到的乘积模m为j的选取方案数。如果我们能够快速合并
f[a],f[b]
f
[
a
]
,
f
[
b
]
得到
f[a+b]
f
[
a
+
b
]
的值,这道题就解决了。由于
观察到这个东西和卷积长得很像,但是卷积是加,它是乘法。于是我们可以利用题目条件把乘法转化为加法。
令m的原根为g,则0到m-1都可以表示成g的幂次。定义离散对数 lg(i)使glg(i)≡i(mod m) l g ( i ) 使 g l g ( i ) ≡ i ( m o d m ) ,并且把 f[i][j] f [ i ] [ j ] 的定义改为表示当前已经选了i个数,得到的乘积模m为 gj g j ,那么上述状态转移方程就可以变为:
答案就是 f[n][lg(x)] f [ n ] [ l g ( x ) ] 。中间算卷积时的模运算可以等卷完了在处理,总复杂度 O(mlognlogm) O ( m l o g n l o g m ) 。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 1 << 14, mod = 1004535809, G = 3;
ll modpow(ll a, ll b, ll p = mod){
ll res = 1;
for(; b; b >>= 1){
if(b & 1) res = res * a % p;
a = a * a % p;
}
return res;
}
void rader(ll *a, int n){
for(int i = 1, j = n >> 1; i < n - 1; i++){
if(i < j) swap(a[i], a[j]);
int k = n >> 1;
for(; j >= k; k >>= 1) j -= k;
if(j < k) j += k;
}
}
void ntt(ll *a, int n, int rev){
rader(a, n);
for(int h = 2; h <= n; h <<= 1){
int hh = h >> 1, wn = modpow(G, (mod - 1 + rev * (mod - 1) / h) % (mod - 1));
for(int i = 0; i < n; i += h){
ll w = 1;
for(int j = i; j < i + hh; j++){
ll x = a[j], y = w * a[j + hh] % mod;
a[j] = (x + y) % mod;
a[j + hh] = (x - y + mod) % mod;
w = w * wn % mod;
}
}
}
if(rev == -1){
int inv = modpow(n, mod - 2);
for(int i = 0; i < n; i++) a[i] = a[i] * inv % mod;
}
}
int n, m, x, S, g, id[maxn], num[maxn];
ll A[maxn], B[maxn], C[maxn];
void init(){
int p = m - 1;
vector<int> fac;
for(int i = 2; i * i <= p; i++) if(p % i == 0){
while(p % i == 0) p /= i;
fac.push_back(i);
}
if(p > 1) fac.push_back(p);
int s = fac.size();
for(g = 2;; g++){
int flag = 1;
for(int i = 0; i < s; i++)
if(modpow(g, (m - 1) / fac[i], m) == 1){flag = 0; break;}
if(flag) break;
}
for(int i = 0, pw = 1; i < m - 1; i++, pw = pw * g % m)
num[id[pw] = i] = pw;
}
int main(){
scanf("%d%d%d%d", &n, &m, &x, &S);
init();
for(int i = 0; i < S; i++){
int t; scanf("%d", &t);
if(t % m > 0) ++A[id[t % m]];
}
int len = 1;
while(len < 2 * m - 1) len <<= 1;
for(int flag = 1; n; n >>= 1){
if(n & 1){
if(!flag){
memcpy(C, A, sizeof(C));
ntt(C, len, 1), ntt(B, len, 1);
for(int i = 0; i < len; i++) B[i] = B[i] * C[i] % mod;
ntt(B, len, -1);
for(int i = m - 1; i < len; i++)
(B[i % (m - 1)] += B[i]) %= mod, B[i] = 0;
} else memcpy(B, A, sizeof(B)), flag = 0;
}
ntt(A, len, 1);
for(int i = 0; i < len; i++) A[i] = A[i] * A[i] % mod;
ntt(A, len, -1);
for(int i = m - 1; i < len; i++)
(A[i % (m - 1)] += A[i]) %= mod, A[i] = 0;
}
printf("%lld\n", B[id[x]]);
return 0;
}
FFT与字符串匹配
例题 洛谷P4173 残缺的字符串
原题链接
考虑定义字符匹配函数
match(x,y)=(x−y)2
m
a
t
c
h
(
x
,
y
)
=
(
x
−
y
)
2
,这样只有当两个字符相等时函数值才为0。于是可以考虑类似的定义字符串匹配函数:
于是只有当此函数值为0时,才代表两个字符串相匹配。对于这一题,由于有任意字符的出现,我们定义字符匹配函数 match(x,y)=(x−y)2xy m a t c h ( x , y ) = ( x − y ) 2 x y ,也就是说原题中字符”*”的值定义为0,其它从1开始递增。
由于我们希望这个算式出现卷积的形式,可以考虑翻转字符串A,于是从B的a位置开始匹配A的匹配函数就变成了:
然后就变成了三个卷积,加一下,如果卷积的第k个位置上为0,那也就代表着B的k-n+1位置与A匹配(注意这里下标都从0开始!),复杂度 O(nlogn) O ( n l o g n ) ,常数略大。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 1 << 20, mod = 998244353, G = 3;
int id[128], A[maxn], B[maxn], ans[maxn], n, m, ppos;
char sa[maxn], sb[maxn], prt[10000000];
void print(int x, char c){
if(x){
static char sta[10];
int tp = 0;
for(; x; x /= 10) sta[tp++] = '0' + x % 10;
while(tp > 0) prt[ppos++] = sta[--tp];
} else prt[ppos++] = '0';
prt[ppos++] = c;
}
inline int modpow(int a, int b){
int res = 1;
for(; b; b >>= 1){
if(b & 1) res = (ll)res * a % mod;
a = (ll)a * a % mod;
}
return res;
}
inline void rader(int *a, int n){
for(int i = 1, j = n >> 1; i < n - 1; i++){
if(i < j) swap(a[i], a[j]);
int k = n >> 1;
for(; j >= k; k >>= 1) j -= k;
if(j < k) j += k;
}
}
void ntt(int *a, int n, int rev){
rader(a, n);
for(int h = 2; h <= n; h <<= 1){
int hh = h >> 1, wn = modpow(G, rev ? mod - 1 - (mod - 1) / h : (mod - 1) / h);
for(int i = 0; i < n; i += h){
ll w = 1;
for(int j = i; j < i + hh; j++){
int x = a[j], y = w * a[j + hh] % mod;
a[j] = (x + y) % mod, a[j + hh] = (x - y + mod) % mod;
w = w * wn % mod;
}
}
}
if(rev){
int inv = modpow(n, mod - 2);
for(int i = 0; i < n; i++) a[i] = (ll)a[i] * inv % mod;
}
}
int main(){
for(char i = 'a'; i <= 'z'; i++) id[i] = i - 'a' + 1;
scanf("%d%d%s%s", &n, &m, sa, sb);
int len = 1;
while(len < n + m - 1) len <<= 1;
for(int i = 0; i < n; i++) A[n - i - 1] = id[sa[i]] * id[sa[i]] * id[sa[i]];
for(int i = 0; i < m; i++) B[i] = id[sb[i]];
ntt(A, len, 0), ntt(B, len, 0);
for(int i = 0; i < len; i++) ans[i] = (ll)A[i] * B[i] % mod;
memset(A, 0, sizeof(A)), memset(B, 0, sizeof(B));
for(int i = 0; i < n; i++) A[n - i - 1] = id[sa[i]];
for(int i = 0; i < m; i++) B[i] = id[sb[i]] * id[sb[i]] * id[sb[i]];
ntt(A, len, 0), ntt(B, len, 0);
for(int i = 0; i < len; i++) ans[i] = ((ll)A[i] * B[i] + ans[i]) % mod;
memset(A, 0, sizeof(A)), memset(B, 0, sizeof(B));
for(int i = 0; i < n; i++) A[n - i - 1] = id[sa[i]] * id[sa[i]];
for(int i = 0; i < m; i++) B[i] = id[sb[i]] * id[sb[i]];
ntt(A, len, 0), ntt(B, len, 0);
for(int i = 0; i < len; i++) ans[i] = (ans[i] - 2LL * A[i] * B[i] % mod + mod) % mod;
ntt(ans, len, 1);
int tp = 0;
for(int i = n - 1; i < m; i++) if(!ans[i]) A[tp++] = i - n + 2;
print(tp, '\n');
for(int i = 0; i < tp; i++) print(A[i], ' ');
fwrite(prt, 1, ppos, stdout);
return 0;
}