多项式算法6:分治 FFT
分治FFT
分治FFT主要是求解以下问题:
给定序列 $g_{1} \cdots g_{n-1}$,求序列 $f_{0} \cdots f_{n-1}$。
其中 $f_{i}= \sum_{j=1}^{i}f_{i-j}g_{j}$,边界为 $f_{0} = 1$(边界不一定为 1,看题目要求)。
模板题在此。
我们主要有两种求解方法,一种是用生成函数构造然后用多项式求逆的方法求解,另一种是CDQ分治法。
第一种解法(多项式求逆)
我们先设 $g_0=0$,然后构造生成函数得:
$$F(x)=\sum_{i=0}^{\infty} f_i x^i$$
$$G(x)=\sum_{i=0}^{\infty} g_i x^i$$
相乘得:
$$F(x) \times G(x) = \sum_{i=0}^{\infty} \sum_{j=0}^{\infty} f_i \times g_j \times x^{i+j}$$
令 $k=i+j$,可得:
$$F(x) \times G(x) = \sum_{k=0}^{\infty}\left( \sum_{j=0}^{k} f_{k-j} \times g_{j}\right) x^k$$
由于最高次项为 $x^{n-1}$,我们可以在模 $x^n$ 的意义下运算,即
$$F(x) \times G(x) = \sum_{k=0}^{n-1}\left( \sum_{j=0}^{k} f_{k-j} \times g_{j}\right) x^k$$
由于 $g_0=0$,$k \gt 0$ 时 $\sum_{j=0}^{k} f_{k-j} \times g_{j} = f_k$;$k = 0$ 时 $\sum_{j=0}^{k} f_{k-j} \times g_{j} = 0$。
所以 $F(x) \times G(x) = \sum_{k=1}^{n-1} f_k x^k$,刚好和 $F(x)$ 差了一个 $f_0$,即 $F(x) \times G(x) + f_0 = F(x)$。
不难得出:
$$F(x)=\frac{f_0}{1-G(x)}$$
这样整个过程就完成了,时间复杂度 $\Theta(n \log^2 n)$。
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
#define ll long long
using namespace std;
// N: capacity of every global work array in this program.
const int N = 1 << 22;
// mod: the prime modulus of all arithmetic; g: base used to derive the NTT
// roots of unity; gi: modular inverse of g (3 * 332748118 = mod + 1).
const int g = 3 , gi = 332748118 , mod = 998244353;
/**
 * Modular exponentiation by repeated squaring: a^b modulo `mod`.
 *
 * @param a base; any 64-bit value. It is reduced into [0, mod) first, so a
 *          negative or oversized base can neither produce a negative result
 *          nor overflow the 64-bit products below.
 * @param b exponent, expected to be >= 0. A negative exponent is treated as
 *          0 (result 1) instead of looping forever.
 * @return a^b mod `mod`, always in [0, mod).
 */
ll qw( ll a , ll b ) {
    a %= mod;
    if ( a < 0 ) {
        a += mod;
    }
    ll ans = 1;
    while ( b > 0 ) {
        if ( b & 1 ) {
            ans = ans * a % mod;
        }
        a = a * a % mod;
        b >>= 1;
    }
    return ans;
}
// rev: bit-reversal table consumed by NTT; rebuilt by pre() for each size.
int rev[N];
// n: sequence length, read by main.
int n;
/**
 * Fills rev[0 .. 2^bit - 1] so that rev[i] is i with its low `bit` bits
 * written in reverse order.
 *
 * @param bit log2 of the transform length. For bit <= 0 the table is the
 *            single entry rev[0] = 0; the general formula is skipped because
 *            it would shift by (bit - 1) = -1, which is undefined behavior.
 */
void pre( int bit ) {
    rev[0] = 0;
    if ( bit <= 0 ) {
        return;
    }
    const int size = 1 << bit;
    const int top = bit - 1;
    // rev[i] is rev[i / 2] shifted down by one, with i's lowest bit moved to
    // the highest position.
    for ( int i = 1 ; i < size ; ++i ) {
        rev[i] = ( rev[i >> 1] >> 1 ) | ( ( i & 1 ) << top );
    }
}
/**
 * In-place iterative number-theoretic transform of F[0 .. len - 1] modulo
 * `mod`.
 *
 * The rev[] table must already describe this `len` (callers run pre() first).
 *
 * @param F   coefficient / value array, overwritten with the transform
 * @param len transform length (the loops assume a power of two)
 * @param on  non-zero: roots are derived from g; zero: roots are derived from
 *            gi. No 1/len scaling is applied here; callers multiply by
 *            qw(len, mod - 2) after the gi pass.
 */
void NTT( ll *F , int len , int on ) {
    // Reorder the input into bit-reversed order.
    for ( int idx = 0 ; idx < len ; ++idx ) {
        const int partner = rev[idx];
        if ( idx < partner ) {
            swap( F[idx] , F[partner] );
        }
    }
    const ll root = on ? g : gi;
    // Butterfly passes over blocks of width 2, 4, ..., len.
    for ( int width = 2 ; width <= len ; width <<= 1 ) {
        const int half = width / 2;
        const ll step = qw( root , ( mod - 1 ) / width );
        for ( int base = 0 ; base < len ; base += width ) {
            ll w = 1;
            for ( int k = 0 ; k < half ; ++k ) {
                ll &lo = F[base + k];
                ll &hi = F[base + k + half];
                const ll u = lo;
                const ll v = w * hi % mod;
                lo = ( u + v ) % mod;
                hi = ( u - v + mod ) % mod;
                w = w * step % mod;
            }
        }
    }
}
// ta, tb: scratch buffers for the transforms inside solve().
ll ta[N] , tb[N];
/**
 * Power-series inverse: writes into b[0 .. len - 1] the coefficients of
 * 1 / a(x) modulo x^len (and modulo `mod`).
 *
 * Uses Newton iteration: from the inverse modulo x^ceil(len/2) it obtains the
 * inverse modulo x^len as b * (2 - a * b). Only a[0 .. len - 1] is read, and
 * the transform length depends only on `len`: the product a * b * b has
 * degree at most 2 * len - 2, so any power of two >= 2 * len avoids cyclic
 * wrap-around.
 *
 * @param len number of coefficients wanted; len <= 0 is a no-op (it used to
 *            recurse without end). The power of two >= 2 * len must not
 *            exceed N, the size of ta/tb.
 * @param a   input coefficients; may be negative, they are reduced into
 *            [0, mod) before use. a[0] must be invertible modulo `mod`.
 * @param b   output array with room for at least len entries.
 */
void solve( int len , ll *a , ll *b ) {
    if ( len <= 0 ) {
        return;
    }
    if ( len == 1 ) {
        b[0] = qw( ( a[0] % mod + mod ) % mod , mod - 2 );
        return;
    }
    const int half = ( len + 1 ) >> 1;
    solve( half , a , b );
    int l = 1;
    int bit = 0;
    while ( l < 2 * len ) {
        l <<= 1;
        ++bit;
    }
    pre( bit );
    for ( int i = 0 ; i < l ; ++i ) {
        ta[i] = ( i < len ) ? ( a[i] % mod + mod ) % mod : 0;
        tb[i] = ( i < half ) ? b[i] : 0;
    }
    NTT( ta , l , 1 );
    NTT( tb , l , 1 );
    // Point-wise Newton step b * (2 - a * b).
    for ( int i = 0 ; i < l ; ++i ) {
        ta[i] = tb[i] * ( ( ( 2 - ta[i] * tb[i] ) % mod + mod ) % mod ) % mod;
    }
    NTT( ta , l , 0 );
    // The gi pass is unscaled, so divide by the transform length here.
    const ll inv = qw( l , mod - 2 );
    for ( int i = 0 ; i < len ; ++i ) {
        b[i] = ta[i] * inv % mod;
    }
}
ll a[N] , b[N];
int main(){
scanf("%d",&n);
for ( int i = 1 ; i < n ; ++i ) {
scanf("%lld",&a[i]);
a[i] = -a[i];
}
a[0] = 1;
solve( n , a , b );
for ( int i = 0 ; i < n ; ++i ) {
printf("%lld ",b[i] * 1ll);
}
return 0;
}
第二种解法(CDQ分治)
对于式子 $f_{i}= \sum_{j=1}^{i}f_{i-j}g_{j}$,我们已知 $f_0$,那么可以进一步求出 $f_1$,然后求出 $f_2$,$\cdots$,一直到 $f_{n-1}$,这样时间复杂度 $\Theta(n^2)$。
我们考虑CDQ分治,目前要求解 $i \in [l,r]$ 所有 $f_i$ 的解。
假设我们已经解得了 $f_l \cdots f_{mid}$,下一步考虑如何快速解得 $f_{mid + 1} \cdots f_r$。
我们可以考虑 $f_l \cdots f_{mid}$ 对 $f_{mid + 1} \cdots f_r$ 的贡献。
设 $T_i = \sum_{j=l}^{mid} f_{j}\,g_{i-j}$($i \in [mid+1,r]$),即 $f_l \cdots f_{mid}$ 对 $f_i$ 的贡献。
我们用 $f_l \cdots f_{mid}$ 和 $g_0 \cdots g_{r-l}$ 去卷积,就可以得到 $T_i$ 序列,累加到 $f_{mid + 1} \cdots f_r$ 上面。
以 $f_{mid+1}$ 为例:
$$f_{mid + 1} = \sum_{j=0}^{mid} f_{j}\,g_{mid+1-j} = \sum_{j=0}^{l - 1} f_{j}\,g_{mid+1-j} + \sum_{j=l}^{mid} f_{j}\,g_{mid+1-j}$$
前面的和式在前面的分治已经累加了贡献,后半部分就是 $T_{mid+1}$,加上即可。对于更大的 $i$,剩下 $j \gt mid$ 的项会在递归处理右半区间时累加。
我们就这样一直分治下去,时间复杂度 $\Theta(n \log^2 n)$。
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
#define ll long long
using namespace std;
// N: capacity of every global work array in this program.
const int N = 1 << 22;
// mod: the prime modulus of all arithmetic; g: base used to derive the NTT
// roots of unity; gi: modular inverse of g (3 * 332748118 = mod + 1).
const int g = 3 , gi = 332748118 , mod = 998244353;
/**
 * Fast modular power: computes a^b modulo `mod` with square-and-multiply.
 *
 * @param a base; reduced into [0, mod) up front so negative or very large
 *          values give a canonical result and cannot overflow a * a.
 * @param b exponent, expected to be >= 0; a negative value is handled as 0
 *          (result 1) rather than never terminating.
 * @return a^b mod `mod`, always in [0, mod).
 */
ll qw( ll a , ll b ) {
    a = ( a % mod + mod ) % mod;
    ll ans = 1;
    for ( ; b > 0 ; b >>= 1 ) {
        if ( b & 1 ) {
            ans = ans * a % mod;
        }
        a = a * a % mod;
    }
    return ans;
}
// rev: bit-reversal table consumed by NTT; rebuilt by pre() for each size.
int rev[N];
// n: sequence length, read by main.
int n;
/**
 * Builds the bit-reversal permutation for a transform of length 2^bit:
 * afterwards rev[i] holds i with its low `bit` bits reversed.
 *
 * @param bit log2 of the transform length. bit <= 0 yields the one-entry
 *            table rev[0] = 0 and returns early, because the recurrence
 *            below would otherwise shift by -1 (undefined behavior).
 */
void pre( int bit ) {
    rev[0] = 0;
    if ( bit <= 0 ) {
        return;
    }
    const int high_bit = 1 << ( bit - 1 );
    for ( int i = 1 ; i < ( 1 << bit ) ; ++i ) {
        // Reuse rev[i / 2]; an odd i additionally sets the highest bit.
        rev[i] = ( rev[i >> 1] >> 1 ) | ( ( i & 1 ) ? high_bit : 0 );
    }
}
/**
 * Iterative in-place number-theoretic transform of F[0 .. len - 1] modulo
 * `mod`.
 *
 * Relies on rev[] having been prepared for this `len` by pre().
 *
 * @param F   data array, replaced by its transform
 * @param len transform length (the butterfly loops assume a power of two)
 * @param on  non-zero selects roots built from g, zero selects roots built
 *            from gi; the gi pass is left unscaled (callers multiply by the
 *            inverse of len afterwards).
 */
void NTT( ll *F , int len , int on ) {
    // Bit-reversal reordering.
    for ( int pos = 0 ; pos < len ; ++pos ) {
        if ( pos < rev[pos] ) {
            swap( F[pos] , F[rev[pos]] );
        }
    }
    int block = 2;
    while ( block <= len ) {
        const int half = block / 2;
        const ll unit = qw( on ? g : gi , ( mod - 1 ) / block );
        for ( int start = 0 ; start <= len - 1 ; start += block ) {
            ll twiddle = 1;
            ll *low = F + start;
            ll *high = low + half;
            for ( int t = 0 ; t < half ; ++t , ++low , ++high ) {
                const ll u = *low;
                const ll v = twiddle * ( *high ) % mod;
                *low = ( u + v ) % mod;
                *high = ( u - v + mod ) % mod;
                twiddle = twiddle * unit % mod;
            }
        }
        block <<= 1;
    }
}
// f: the sequence being solved for; gg: the given coefficients g_i
// (main stores 0 at index 0).
ll f[N] , gg[N];
/**
 * Cyclic convolution of length 2^bit, computed through NTT.
 *
 * @param a   first operand; on return holds the product coefficients
 * @param b   second operand; left in transformed form on return, so callers
 *            must not expect its original contents to survive
 * @param bit log2 of the transform length; both arrays must have 2^bit
 *            usable entries
 */
void mul( ll *a , ll *b , int bit ) {
    const int size = 1 << bit;
    pre( bit );
    NTT( a , size , 1 );
    NTT( b , size , 1 );
    // Point-wise product in the transformed domain.
    for ( int k = 0 ; k < size ; ++k ) {
        a[k] = a[k] * b[k] % mod;
    }
    NTT( a , size , 0 );
    // The gi pass is unscaled: multiply by the modular inverse of the length.
    const ll scale = qw( (ll)size , mod - 2 );
    for ( ll *p = a ; p != a + size ; ++p ) {
        *p = *p * scale % mod;
    }
}
// a, b: scratch operands handed to mul() inside solve().
ll a[N] , b[N];
/**
 * CDQ divide and conquer over the index range [l, r] of f.
 *
 * On entry f[l .. r] already contains every contribution coming from indices
 * below l. The left half is solved first, its finished values are convolved
 * with gg[0 .. r - l] and the result is added to the right half, which is
 * then solved recursively.
 *
 * @param l first index of the range
 * @param r last index of the range; an empty or single-element range
 *          (l >= r) is a no-op, where l > r used to recurse without end
 */
void solve( int l , int r ) {
    if ( l >= r ) {
        return;
    }
    const int mid = l + ( r - l ) / 2;  // same midpoint, no l + r overflow
    solve( l , mid );
    const int left_cnt = mid - l + 1;   // number of finished values f[l..mid]
    const int span = r - l;             // highest index of gg that is needed
    int bit = 0;
    int len = 1;
    // The product has degree (left_cnt - 1) + span; len must exceed it so the
    // cyclic convolution does not wrap around.
    while ( len <= ( left_cnt - 1 ) + span ) {
        len <<= 1;
        ++bit;
    }
    for ( int i = 0 ; i < len ; ++i ) {
        a[i] = ( i < left_cnt ) ? f[l + i] : 0;
        b[i] = ( i <= span ) ? gg[i] : 0;
    }
    mul( a , b , bit );
    // a[i - l] is the sum over j in [l, mid] of f[j] * gg[i - j].
    for ( int i = mid + 1 ; i <= r ; ++i ) {
        f[i] = ( f[i] + a[i - l] ) % mod;
    }
    solve( mid + 1 , r );
}
int main(){
scanf("%d",&n);
for ( int i = 1 ; i < n ; ++i ) {
scanf("%lld",&gg[i]);
}
gg[0] = 0;
f[0] = 1;
solve( 0 , n - 1 );
for ( int i = 0 ; i < n ; ++i ) {
printf("%lld ",f[i]);
}
return 0;
}
其实多项式求逆法时间效率上强于CDQ分治法。