P4721 【模板】分治 FFT
题意:
-
给定序列 g 1 … n − 1 g_{1\dots n - 1} g1…n−1,求序列 f 0 … n − 1 f_{0\dots n - 1} f0…n−1
其中 f i = ∑ j = 1 i f i − j g j f_i=\sum_{j=1}^if_{i-j}g_j fi=∑j=1ifi−jgj ,边界为 f 0 = 1 f_0=1 f0=1
分析:
-
若直接按照题目给的式子展开求和, O ( n 2 ) O(n^2) O(n2) ,会超时
-
考虑分治:
先将f_i展开来观察一下:
f 1 = g 1 f 0 f_1=g_1f_0 f1=g1f0
f 2 = g 1 f 1 + g 2 f 0 f_2=g_1f_1+g_2f_0 f2=g1f1+g2f0
f 3 = g 1 f 2 + g 2 f 1 + g 3 f 0 f_3=g_1f_2+g_2f_1+g_3f_0 f3=g1f2+g2f1+g3f0
f 4 = g 1 f 3 + g 2 f 2 + g 3 f 1 + g 4 f 0 f_4=g_1f_3+g_2f_2+g_3f_1+g_4f_0 f4=g1f3+g2f2+g3f1+g4f0
f 5 = g 1 f 4 + g 2 f 3 + g 3 f 2 + g 4 f 1 + g 5 f 0 f_5=g_1f_4+g_2f_3+g_3f_2+g_4f_1+g_5f_0 f5=g1f4+g2f3+g3f2+g4f1+g5f0
考虑分治的话,要考虑区间 [ l , r ] [l,r] [l,r] 中, [ l , m i d ] [l,mid] [l,mid] 和 [ m i d + 1 , r ] [mid+1,r] [mid+1,r] 之间的关系
容易发现, [ l , m i d ] [l,mid] [l,mid] 会对 [ m i d + 1 , r ] [mid+1,r] [mid+1,r] 产生贡献(令其为 v a l i val_i vali )
v a l i = ∑ j = l m i d g i − j f j val_i=\sum_{j=l}^{mid}g_{i-j}f_{j} vali=∑j=lmidgi−jfj , i ∈ [ m i d + 1 , r ] i\in[mid+1,r] i∈[mid+1,r]
v a l i val_i vali 是由两个函数卷积得出的,将 g i − j f j g_{i-j}f_{j} gi−jfj 中的 ( i − j ) , j (i-j),j (i−j),j 看成多项式的项数(次数),则其系数为 f j f_{j} fj ,可通过 N T T NTT NTT 加速算出
-
分治过程中区间之间的关系:
- 令 l = 0 l=0 l=0 , r = 7 r=7 r=7
-
(
l
,
m
i
d
)
−
−
>
(
m
i
d
+
1
,
r
)
(l,mid) -->(mid+1,r)
(l,mid)−−>(mid+1,r)
- ( 0 , 0 ) − − > ( 1 , 1 ) (0,0)-->(1,1) (0,0)−−>(1,1)
- ( 0 , 1 ) − − > ( 2 , 3 ) (0,1)-->(2,3) (0,1)−−>(2,3)
- ( 2 , 2 ) − − > ( 3 , 3 ) (2,2)-->(3,3) (2,2)−−>(3,3)
- ( 0 , 3 ) − − > ( 4 , 7 ) (0,3)-->(4,7) (0,3)−−>(4,7)
- ( 4 , 4 ) − − > ( 5 , 5 ) (4,4)-->(5,5) (4,4)−−>(5,5)
- ( 4 , 5 ) − − > ( 6 , 7 ) (4,5)-->(6,7) (4,5)−−>(6,7)
- ( 6 , 6 ) − − > ( 7 , 7 ) (6,6)-->(7,7) (6,6)−−>(7,7)
-
f
i
=
f_i=
fi=过程中算得的
v
a
l
i
val_i
vali 之和
时间复杂度: O ( n l o g 2 n ) O(nlog^2n) O(nlog2n) , 分治一个 l o g log log
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N = (1<<20)+5, mo=998244353;
inline int binpow(int a,int b)
{
int ans=1;
while(b)
{
if(b&1) ans = ans*a%mo;
a = a*a%mo;
b >>= 1;
}
return ans;
}
int rev[N];
void ntt(int *a, int n, int inv)
{
for(int i=0;i<n;i++)
{
if(i < rev[i]) swap(a[i], a[rev[i]]);
}
for(int len=1;len<n;len<<=1)
{
int Wn = binpow(3, (mo-1)/(len<<1));
if(inv == -1) Wn = binpow(Wn, mo-2);
for(int i=0;i<n;i+=(len<< 1))
{
int w=1;
for(int j=0;j<len;j++, w = (w*Wn)%mo)
{
int x = a[i + j], y = w*a[i+j+len]%mo;
a[i+j] = (x+y)%mo; a[i+j+len] = (x-y+mo)%mo;
}
}
}
if(inv == -1)
{
int fg=binpow(n, mo-2);
for(int i=0;i<n;i++) a[i] = a[i]*fg%mo;
}
}
inline void mul(int *a,int *b,int k)
{
int s=1<<k;
for(int i=1;i<s;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1));
ntt(a,s,1); ntt(b,s,1);
for(int i=0;i<s;i++) a[i] = a[i]*b[i]%mo;
ntt(a,s,-1);
}
int g[N], a[N], b[N], f[N];
inline void binary(int l,int r)
{
if(r == l) return ;
int mid = (l+r)>>1;
binary(l, mid);
int sum=2, k=1;
while(sum <= (mid-l+r-l)) k++, sum <<= 1;
for(int i=0;i<sum;i++) a[i] = b[i] = 0;
for(int i=l;i<=mid;i++) a[i-l] = f[i];
for(int i=1;i<=r-l;i++) b[i] = g[i];
mul(a, b, k);
for(int i=mid+1;i<=r;i++) f[i] = (f[i]+a[i-l])%mo;
binary(mid+1, r);
}
signed main()
{
ios::sync_with_stdio(false);cin.tie(nullptr);cout.tie(nullptr);
int n;
cin>>n;
f[0] = 1;
for(int i=1;i<n;i++) cin>>g[i];
binary(0, n-1);
for(int i=0;i<n;i++) cout<<f[i]<<' ';
return 0;
}