前置知识
前言
分治 F F T FFT FFT是基于分治的算法,通过每次计算左区间对右区间的贡献,来降低 F F T FFT FFT的时间复杂度。
情景代入
给定序列 g 1 , g 2 … , g n − 1 g_1,g_2\dots,g_{n-1} g1,g2…,gn−1,求 f 0 , f 1 , … , f n − 1 f_0,f_1,\dots,f_{n-1} f0,f1,…,fn−1。
其中 f i = ∑ j = 1 i f i − j g j f_i=\sum\limits_{j=1}^if_{i-j}g_j fi=j=1∑ifi−jgj,边界为 f 0 = 1 f_0=1 f0=1。
答案对 998244353 998244353 998244353取模。
2 ≤ n ≤ 1 0 5 , 0 ≤ g i < 998244353 2\leq n\leq 10^5,0\leq g_i<998244353 2≤n≤105,0≤gi<998244353
分析
由 f f f的递推式可看出,它是其前面的项于多项式 g g g的卷积。多项式 f f f不能一次全部求出,因为其每一项都和前面的项有关。我们可以用分治,将左半区间对右半区间的贡献提前累加到右半区间。
设当前左半区间为 [ l , m i d ] [l,mid] [l,mid],右半区间为 [ m i d + 1 , r ] [mid+1,r] [mid+1,r], k ∈ [ m i d + 1 , r ] k\in[mid+1,r] k∈[mid+1,r],考虑左半区间对右半区间内的点 k k k的贡献。
v k = ∑ i = l m i d f i × g k − i v_k=\sum\limits_{i=l}^{mid}f_i\times g_{k-i} vk=i=l∑midfi×gk−i
设 F ( x ) = ∑ i = 0 m i d − l f i + m i d x i , G ( x ) = ∑ i = 0 r − l g i x i F(x)=\sum\limits_{i=0}^{mid-l}f_{i+mid}x_i,G(x)=\sum\limits_{i=0}^{r-l}g_ix^i F(x)=i=0∑mid−lfi+midxi,G(x)=i=0∑r−lgixi, H ( x ) = F ∗ G H(x)=F*G H(x)=F∗G,则
v k = H k − l v_k=H_{k-l} vk=Hk−l
所以左半区间 [ l , m i d ] [l,mid] [l,mid]对右半区间 [ m i d + 1 , r ] [mid+1,r] [mid+1,r]的贡献可以通过一次多项式乘法得出。
时间复杂度为 O ( n log 2 n ) O(n\log^2 n) O(nlog2n)。
code
#include<bits/stdc++.h>
using namespace std;
long long g[500005],f[500005],a1[500005],a2[500005];
const long long G=3,mod=998244353;
long long mi(long long t,long long v){
if(!v) return 1;
long long re=mi(t,v/2);
re=re*re%mod;
if(v&1) re=re*t%mod;
return re;
}
void ch(long long *a,int l){
for(int i=1,j=l/2,k;i<l-1;i++){
if(i<j) swap(a[i],a[j]);
k=l/2;
while(j>=k){
j-=k;k>>=1;
}
j+=k;
}
}
void ntt(long long *a,int l,int fl){
long long wn,w;
for(int i=2;i<=l;i<<=1){
if(fl==1) wn=mi(G,(mod-1)/i);
else wn=mi(G,mod-1-(mod-1)/i);
for(int j=0;j<l;j+=i){
w=1;
for(int k=j;k<j+i/2;k++,w=w*wn%mod){
long long t=a[k],u=w*a[k+i/2]%mod;
a[k]=(t+u)%mod;
a[k+i/2]=(t-u+mod)%mod;
}
}
}
if(fl==-1){
long long ny=mi(l,mod-2);
for(int i=0;i<l;i++) a[i]=a[i]*ny%mod;
}
}
void solve(int l,int r){
if(l==r){
f[l]=(f[l]+g[l])%mod;
return;
}
int mid=l+r>>1;
solve(l,mid);
int len=1;
while(len<r-l+1) len<<=1;
for(int i=0;i<len;i++) a1[i]=a2[i]=0;
for(int i=0;i<=mid-l;i++) a1[i]=f[i+l];
for(int i=0;i<=r-l;i++) a2[i]=g[i];
ch(a1,len);ch(a2,len);
ntt(a1,len,1);ntt(a2,len,1);
for(int i=0;i<len;i++){
a1[i]=a1[i]*a2[i]%mod;
}
ch(a1,len);
ntt(a1,len,-1);
for(int i=mid+1;i<=r;i++){
f[i]=(f[i]+a1[i-l])%mod;
}
solve(mid+1,r);
}
int main()
{
int n;
scanf("%d",&n);
for(int i=1;i<n;i++){
scanf("%lld",&g[i]);
}
solve(0,n-1);
f[0]=1;
for(int i=0;i<=n-1;i++){
printf("%lld ",f[i]);
}
return 0;
}