题目描述
给出 n − 1 n-1 n−1 次多项式 A ( x ) A(x) A(x),求一个   m o d     x n \bmod{\:x^n} modxn 下的多项式 B ( x ) B(x) B(x),满足 B ( x ) ≡ e A ( x ) B(x)≡e^{A(x)} B(x)≡eA(x).
输入输出格式
输入格式:
第一行一个整数
n
n
n.
下一行有 n n n 个整数,依次表示多项式的系数 a 0 , a 1 , ⋯   , a n − 1 a_0, a_1, \cdots, a_{n-1} a0,a1,⋯,an−1.
保证 a 0 = 0 a_0=0 a0=0.
输出格式:
输出
n
n
n 个整数,表示答案多项式中的系数
a
0
,
a
1
,
⋯
 
,
a
n
−
1
a_0, a_1, \cdots, a_{n-1}
a0,a1,⋯,an−1.
输入输出样例
输入样例#1:
6
0 927384623 817976920 427326948 149643566 610586717
输出样例#1:
1 927384623 878326372 3882 273455637 998233543
说明
对于 100 % 100\% 100% 的数据, n ≤ 1 0 5 n≤10^5 n≤105.
分析:
使用牛顿迭代,设
G
(
x
)
G(x)
G(x)为在
 
m
o
d
 
 
x
n
\bmod{\:x^n}
modxn意义下的多项式,
G
′
(
x
)
G'(x)
G′(x)为在
 
m
o
d
 
 
x
⌈
n
2
⌉
\bmod{\:x^{\lceil\frac{n}{2}\rceil}}
modx⌈2n⌉意义下的多项式,有
G
(
x
)
=
G
′
(
x
)
(
F
(
x
)
−
l
n
(
G
′
(
X
)
)
+
1
)
G(x)=G'(x)(F(x)-ln(G'(X))+1)
G(x)=G′(x)(F(x)−ln(G′(X))+1)
递归求即可,
n
n
n要在递归前开两倍。
代码:
// luogu-judger-enable-o2
#include <iostream>
#include <cstdio>
#include <cmath>
#define LL long long
const int maxn=7e5+7;
const LL mod=998244353;
const LL G=3;
using namespace std;
int n,len,r[maxn];
LL f[maxn],g[maxn],h[maxn],s[maxn],inv[maxn],w[maxn],a[maxn],b[maxn],c[maxn];
LL power(LL x,LL y)
{
if (y==1) return x;
LL c=power(x,y/2);
c=(c*c)%mod;
if (y%2) c=(c*x)%mod;
return c;
}
void ntt(LL *a,int f)
{
for (int i=0;i<len;i++)
{
if (i<r[i]) swap(a[i],a[r[i]]);
}
w[0]=1;
for (int i=2;i<=len;i*=2)
{
LL wn;
if (f==1) wn=power(G,(LL)(mod-1)/i);
else wn=power(G,(LL)(mod-1)-(mod-1)/i);
for (int j=i/2;j>=0;j-=2) w[j]=w[j/2];
for (int j=1;j<i/2;j+=2) w[j]=(w[j-1]*wn)%mod;
for (int j=0;j<len;j+=i)
{
for (int k=0;k<i/2;k++)
{
LL u=a[j+k],v=a[j+k+i/2]*w[k]%mod;
a[j+k]=(u+v)%mod;
a[j+k+i/2]=(u+mod-v)%mod;
}
}
}
if (f==-1)
{
LL inv=power(len,mod-2);
for (int i=0;i<len;i++) a[i]=a[i]*inv%mod;
}
}
void NTT(LL *x,LL *y,LL *z,int n,int m)
{
len=1;
int k=0;
while (len<=(n+m)) len*=2,k++;
for (int i=0;i<len;i++)
{
r[i]=(r[i>>1]>>1)|((i&1)<<(k-1));
}
for (int i=0;i<len;i++)
{
if (i<n) a[i]=x[i]; else a[i]=0;
if (i<m) b[i]=y[i]; else b[i]=0;
}
ntt(a,1); ntt(b,1);
for (int i=0;i<len;i++) z[i]=(a[i]*b[i])%mod;
ntt(z,-1);
}
void getinv(LL *f,LL *g,int deg)
{
if (deg==1)
{
g[0]=power(f[0],mod-2);
return;
}
int d=(deg+1)/2;
getinv(f,g,d);
NTT(f,g,c,deg,d);
c[0]=(2+mod-c[0])%mod;
for (int i=1;i<deg;i++) c[i]=(mod-c[i])%mod;
NTT(c,g,g,deg,d);
for (int i=deg;i<len;i++) g[i]=0;
}
void tran(LL *a,LL *b,int deg)
{
for (int i=1;i<deg;i++)
b[i-1]=(i*a[i])%mod;
b[n-1]=0;
}
void intran(LL *a,LL *b,int deg)
{
for (int i=1;i<deg;i++) b[i]=a[i-1]*power(i,mod-2)%mod;
b[0]=0;
}
void ln(LL *f,LL *g,int n)
{
getinv(f,inv,n);
tran(f,f,n);
NTT(f,inv,f,n,n);
intran(f,g,n);
}
void solve(LL *f,LL *g,int deg)
{
if (deg==1)
{
g[0]=1;
return;
}
int mid=(deg+1)/2;
solve(f,g,mid);
for (int i=0;i<mid;i++) s[i]=g[i];
ln(s,h,mid);
h[0]=(f[0]+1+mod-h[0])%mod;
for (int i=1;i<mid;i++) h[i]=(f[i]+mod-h[i])%mod;
NTT(h,g,g,mid,mid);
for (int i=deg;i<len;i++) g[i]=h[i]=0;
}
int main()
{
scanf("%d",&n);
for (int i=0;i<n;i++) scanf("%lld",&f[i]);
solve(f,g,n<<1);
for (int i=0;i<n;i++) printf("%lld ",g[i]);
}