题目
https://www.luogu.com.cn/problem/P4721
给定序列
g
1
,
g
2
,
.
.
g
n
−
1
g_1,g_2,..g_{n-1}
g1,g2,..gn−1,求
f
0
,
f
1
.
.
.
f
n
−
1
f_0,f_1...f_{n-1}
f0,f1...fn−1
f
i
=
∑
j
=
1
i
f
j
g
i
−
j
m
o
d
998244353
f_i=\sum_{j=1}^{i}f_jg_{i-j}\ mod\ 998244353
fi=j=1∑ifjgi−j mod 998244353
f
0
=
1
f_0=1
f0=1
思路
假设现在算区间
[
l
,
r
]
[l,r]
[l,r]的
f
f
f,先算区间
[
l
,
m
i
d
]
[l,mid]
[l,mid]的
f
f
f,算完后求出区间
[
l
,
m
i
d
]
[l,mid]
[l,mid]对
x
∈
[
m
i
d
+
1
,
r
]
x\in[mid+1,r]
x∈[mid+1,r]的贡献。
f
x
=
∑
i
=
l
m
i
d
f
i
g
x
−
i
f_x=\sum_{i=l}^{mid}f_ig_{x-i}
fx=i=l∑midfigx−i
然后再算区间
[
m
i
d
+
1
,
r
]
[mid+1,r]
[mid+1,r]
具体计算见https://blog.csdn.net/qq_43520313/article/details/109322001
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const double Pi=acos(-1.0);
const int N=2100009;
const ll mod=998244353,G=3;//G是mod的原根
ll a[N],b[N],c[N],d[N],f[N],p[N],p1[N],Gi,_inv;//Gi是原根的逆元,inv是lim的逆元
int n,m,bit,lim,r[N];//lim表示当前运算的长度
ll qpow(ll a,ll b){ll res=1;a%=mod;while(b){if(b&1)res=res*a%mod;a=a*a%mod;b>>=1;}return res;}
void _init(int n)//多项式可能最长的长度,只需初始化一次
{
Gi=qpow(G,mod-2);
for(int i=1;i<n;i<<=1)p[i]=qpow(G,(mod-1)/(i<<1)),p1[i]=qpow(Gi,(mod-1)/(i<<1));
}
void init(int n,int m)//n,m表示最高长度,每次运算调用一次
{
lim=1,bit=0;
while(lim<n+m-1)lim<<=1,bit++;
_inv=qpow(lim,mod-2);
for(int i=0;i<lim;i++)r[i]=(r[i>>1]>>1)|((i&1)<<(bit-1));
}
void NTT(ll *a,int type)
{
for(int i=0;i<lim;i++)if(i<r[i])swap(a[i],a[r[i]]);
for(int mid=1;mid<lim;mid<<=1)
{
ll W=(type==1?p[mid]:p1[mid]);
for(int r=mid<<1,j=0;j<lim;j+=r)
{
ll e=1;
for(int k=0;k<mid;k++,e=e*W%mod)
{
ll x=a[j+k],y=e*a[j+k+mid]%mod;
a[j+k]=(x+y)%mod,a[j+k+mid]=(x-y+mod)%mod;
}
}
}
if(type==-1)for(int i=0;i<lim;i++)a[i]=a[i]*_inv%mod;
}
void solve(int l,int r) {
if(l==r)return ;
int mid=(l+r)>>1;
solve(l,mid);
for(int i=0; i<=mid-l; i++)b[i]=f[i+l];
for(int i=mid-l+1; i<=r-l; i++)b[i]=0;
for(int i=0; i<=r-l; i++)c[i]=a[i];
init(r-l+1,r-l+1);
NTT(b,1),NTT(c,1);
for(int i=0; i<lim; i++)d[i]=b[i]*c[i]%mod;
NTT(d,-1);
for(int i=mid+1; i<=r; i++)f[i]=(f[i]+d[i-l])%mod;
for(int i=0; i<lim; i++)b[i]=c[i]=d[i]=0;
solve(mid+1,r);
}
int main()
{
_init(N);
scanf("%d",&n);
for(int i=1; i<=n-1; i++)scanf("%lld",&a[i]);
f[0]=1;
solve(0,n-1);
for(int i=0; i<n; i++)printf("%lld ",f[i]);
return 0;
}