已知 g ( x ) g(x) g(x),你要求出 f ( x ) ≡ e g ( x ) ( m o d x n ) f(x)\equiv e^{g(x)} \pmod{x^n} f(x)≡eg(x)(modxn)。
前置知识:多项式求逆,多项式 ln \ln ln,泰勒展开。
分析
此题使用牛顿迭代法求解。
牛顿迭代可用来求函数的零点。可以快速求出 G ( F ( x ) ) = 0 G(F(x))=0 G(F(x))=0 的零点。下面是多项式牛顿迭代公式的推导。
类似于多项式求逆,递归求解。
如果我们已经求出一个
f
(
x
)
f(x)
f(x) 满足
G
(
f
(
x
)
)
≡
0
(
m
o
d
x
⌈
n
2
⌉
)
G(f(x))\equiv0\pmod{x^{\lceil\frac n2\rceil}}
G(f(x))≡0(modx⌈2n⌉)
要求出一个
F
(
x
)
F(x)
F(x) 满足
G
(
F
(
x
)
)
≡
0
(
m
o
d
x
n
)
G(F(x))\equiv0\pmod{x^n}
G(F(x))≡0(modxn)
容易得到
F
(
x
)
−
f
(
x
)
≡
0
(
m
o
d
x
⌈
n
2
⌉
)
F(x)-f(x)\equiv0\pmod{x^{\lceil\frac n2\rceil}}
F(x)−f(x)≡0(modx⌈2n⌉)
进一步得到
(
F
(
x
)
−
f
(
x
)
)
2
≡
0
(
m
o
d
x
n
)
(F(x)-f(x))^2\equiv0\pmod{x^n}
(F(x)−f(x))2≡0(modxn)
如果再给上式乘上 F ( x ) − f ( x ) F(x)-f(x) F(x)−f(x) 的若干次方,它仍然模 x n x^n xn 等于 0 0 0。
即对 i ≥ 2 i\ge2 i≥2,有 ( F ( x ) − f ( x ) ) i ≡ 0 ( m o d x n ) (F(x)-f(x))^i\equiv0\pmod{x^n} (F(x)−f(x))i≡0(modxn)
对
G
(
F
(
x
)
)
G(F(x))
G(F(x)) 在
f
(
x
)
f(x)
f(x) 处泰勒展开得
G
(
F
(
x
)
)
≡
∑
i
=
0
∞
G
(
i
)
(
f
(
x
)
)
(
F
(
x
)
−
f
(
x
)
)
i
i
!
(
m
o
d
x
n
)
G(F(x))\equiv\sum\limits_{i=0}^{\infty}\dfrac{G^{(i)}(f(x))(F(x)-f(x))^i}{i!}\pmod{x^n}
G(F(x))≡i=0∑∞i!G(i)(f(x))(F(x)−f(x))i(modxn)
由上面的结论得
G
(
F
(
x
)
)
≡
G
(
f
(
x
)
)
+
G
′
(
f
(
x
)
)
(
F
(
x
)
−
f
(
x
)
)
(
m
o
d
x
n
)
G(F(x))\equiv G(f(x))+G'(f(x))(F(x)-f(x))\pmod{x^n}
G(F(x))≡G(f(x))+G′(f(x))(F(x)−f(x))(modxn)
因为 G ( F ( x ) ) ≡ 0 ( m o d x n ) G(F(x))\equiv0\pmod{x^n} G(F(x))≡0(modxn)
所以
G
(
f
(
x
)
)
+
G
′
(
f
(
x
)
)
(
F
(
x
)
−
f
(
x
)
)
≡
0
(
m
o
d
x
n
)
G(f(x))+G'(f(x))(F(x)-f(x))\equiv0\pmod{x^n}
G(f(x))+G′(f(x))(F(x)−f(x))≡0(modxn)
整理得到
F
(
x
)
≡
f
(
x
)
−
G
(
f
(
x
)
)
G
′
(
f
(
x
)
)
(
m
o
d
x
n
)
F(x)\equiv f(x)-\dfrac{G(f(x))}{G'(f(x))}\pmod{x^n}
F(x)≡f(x)−G′(f(x))G(f(x))(modxn)
回到本题。
我们要求
f
(
x
)
f(x)
f(x) 满足
f
(
x
)
≡
e
g
(
x
)
(
m
o
d
x
n
)
f(x)\equiv e^{g(x)}\pmod{x^n}
f(x)≡eg(x)(modxn)
两边取对数,右边移到左边得
ln
f
(
x
)
−
g
(
x
)
≡
0
(
m
o
d
x
n
)
\ln f(x)-g(x)\equiv0\pmod{x^n}
lnf(x)−g(x)≡0(modxn)
由上面牛顿迭代的式子,令
G
(
f
(
x
)
)
=
ln
f
(
x
)
−
g
(
x
)
G(f(x))=\ln f(x)-g(x)
G(f(x))=lnf(x)−g(x)
则
F
(
x
)
≡
f
(
x
)
−
ln
f
(
x
)
−
g
(
x
)
(
ln
f
(
x
)
−
g
(
x
)
)
′
(
m
o
d
x
n
)
≡
f
(
x
)
−
ln
f
(
x
)
−
g
(
x
)
1
f
(
x
)
(
m
o
d
x
n
)
≡
f
(
x
)
(
1
−
ln
f
(
x
)
+
g
(
x
)
)
(
m
o
d
x
n
)
\begin{aligned} F(x)&\equiv f(x)-\dfrac{\ln f(x)-g(x)}{(\ln f(x)-g(x))'}\pmod{x^n}\\ &\equiv f(x)-\dfrac{\ln f(x)-g(x)}{\dfrac{1}{f(x)}}\pmod{x^n}\\ &\equiv f(x)(1-\ln f(x)+g(x))\pmod{x^n} \end{aligned}
F(x)≡f(x)−(lnf(x)−g(x))′lnf(x)−g(x)(modxn)≡f(x)−f(x)1lnf(x)−g(x)(modxn)≡f(x)(1−lnf(x)+g(x))(modxn)
这样,就可以递归求出 e g ( x ) e^{g(x)} eg(x) 了。
代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=(1<<18)+1;
const ll mod=998244353,g=3,inv2=499122177;
int len=1,n;
ll a1[N],w,wn,a[N],ans[N],invans[N],lnans[N],da[N],inva[N];
ll ksm(ll a,ll b)
{
ll ans=1;
while(b){
if(b&1) ans=ans*a%mod;
a=a*a%mod;
b>>=1;
}
return ans;
}
void change(ll num[])
{
for(int i=1,j=len/2;i<len-1;i++){
if(i<j) swap(num[i],num[j]);
int k=len/2;
while(j>=k) j-=k,k>>=1;
if(j<k) j+=k;
}
}
void ntt(ll num[],int fl)
{
for(int i=2;i<=len;i<<=1){
if(fl==1) wn=ksm(g,(mod-1)/i);
else wn=ksm(g,mod-1-(mod-1)/i);
for(int j=0;j<len;j+=i){
w=1;
for(int k=j;k<j+i/2;k++){
ll u=w*num[k+i/2]%mod,t=num[k];
num[k]=(t+u)%mod;
num[k+i/2]=(t-u+mod)%mod;
w=w*wn%mod;
}
}
}
if(fl==-1){
ll inv=ksm(len,mod-2);
for(int i=0;i<len;i++) num[i]=num[i]*inv%mod;
}
}
int read()
{
int sum=0,c=getchar();
while(c<48||c>57) c=getchar();
while(c>=48&&c<=57) sum=sum*10+c-48,c=getchar();
return sum;
}
void getinv(int n,ll a[],ll ans[])
{
if(n==1){ans[0]=ksm(a[0],mod-2);return;}
getinv((n+1)/2,a,ans);
len=1;
while(len<2*n) len*=2;
for(int i=0;i<n;i++) a1[i]=a[i];
for(int i=n;i<len;i++) a1[i]=0;
change(a1),change(ans);
ntt(a1,1),ntt(ans,1);
for(int i=0;i<len;i++) ans[i]=ans[i]*(2-ans[i]*a1[i]%mod+mod)%mod;
change(ans),ntt(ans,-1);
for(int i=n;i<len;i++) ans[i]=0;
}
void getln(int n,ll a[],ll ln[])
{
for(int i=1;i<n;i++) da[i-1]=a[i]*i;
da[n-1]=0;
memset(inva,0,sizeof(inva));
getinv(n,a,inva);
len=1;
while(len<2*n) len*=2;
change(da),change(inva);
ntt(da,1),ntt(inva,1);
for(int i=0;i<len;i++) ln[i]=da[i]*inva[i]%mod;
change(ln),ntt(ln,-1);
for(int i=len-1;i>=0;i--) ln[i+1]=ksm(i+1,mod-2)*ln[i]%mod;
for(int i=n;i<len;i++) ln[i]=0;
ln[0]=0;
}
void getsqrt(int n,ll a[],ll ans[])
{
if(n==1){ans[0]=a[0];return;}
getsqrt((n+1)/2,a,ans);
len=1;
while(len<2*n) len*=2;
memset(invans,0,sizeof(invans));
getinv(n,ans,invans);
for(int i=0;i<n;i++) a1[i]=a[i];
for(int i=n;i<len;i++) a1[i]=0;
change(a1),change(invans);
ntt(a1,1),ntt(invans,1);
for(int i=0;i<len;i++) a1[i]=a1[i]*invans[i]%mod;
change(a1),ntt(a1,-1);
for(int i=0;i<n;i++) ans[i]=(a1[i]+ans[i])*inv2%mod;
for(int i=n;i<len;i++) ans[i]=0;
}
void getexp(int n,ll a[],ll ans[])
{
if(n==1){ans[0]=1;return;}
getexp((n+1)/2,a,ans);
len=1;
while(len<2*n) len*=2;
getln(n,ans,lnans);
for(int i=0;i<n;i++) lnans[i]=(-lnans[i]+a[i]+mod)%mod;
lnans[0]++;
change(ans),change(lnans);
ntt(ans,1),ntt(lnans,1);
for(int i=0;i<len;i++) ans[i]=ans[i]*lnans[i]%mod;
change(ans),ntt(ans,-1);
for(int i=n;i<len;i++) ans[i]=0;
}
int main()
{
scanf("%d",&n);
for(int i=0;i<n;i++) scanf("%lld",&a[i]),a[i]%=mod;
getexp(n,a,ans);
for(int i=0;i<n;i++) printf("%lld ",ans[i]);
}