背景:
多项式全家桶
eating...
\text{eating...}
eating...
题目传送门:
https://www.luogu.org/problemnew/show/P5205
题意:
求一个多项式
G
(
x
)
G(x)
G(x),使得:
G
2
(
x
)
≡
F
(
x
)
(
m
o
d
  
x
n
)
G^2(x)≡F(x)(\mod x^n)
G2(x)≡F(x)(modxn)。
思路 1 1 1:
两边同时去
ln
\ln
ln,得:
ln
G
2
(
x
)
≡
ln
F
(
x
)
(
m
o
d
  
x
n
)
2
ln
G
(
x
)
≡
ln
F
(
x
)
(
m
o
d
  
x
n
)
ln
G
(
x
)
≡
ln
F
(
x
)
2
(
m
o
d
  
x
n
)
\begin{aligned}\ln G^2(x)&≡\ln F(x)(\mod x^n)\\ 2\ln G(x)&≡\ln F(x)(\mod x^n)\\ \ln G(x)&≡\frac{\ln F(x)}{2}(\mod x^n)\end{aligned}
lnG2(x)2lnG(x)lnG(x)≡lnF(x)(modxn)≡lnF(x)(modxn)≡2lnF(x)(modxn)
因此对
F
F
F取
ln
\ln
ln,得到的结果除以
2
2
2,最后取
exp
\text{exp}
exp即可。
为什么是取
exp
\text{exp}
exp,因为:
G
(
x
)
=
e
ln
G
(
x
)
G(x)=e^{\ln G(x)}
G(x)=elnG(x)。
以后还是能用
int
\text{int}
int,就不用
long long
\text{long long}
long long吧,
40pts
\text{40pts}
40pts与
100pts
\text{100pts}
100pts的差距。
代码 1 1 1:
#include<cstdio>
#include<cstring>
#include<algorithm>
#define LL long long
const LL mod=998244353,G=3,inv_G=332748118;
using namespace std;
int a[1000010],b[1000010],f[1000010],g[1000010],g1[1000010],g2[1000010],g3[1000010];
int limit,n,l,r[1000010];
int dg(int x,int k)
{
if(!k) return 1;
int op=dg(x,k>>1);
if(k&1) return (LL)op*op%mod*x%mod; else return (LL)op*op%mod;
}
int inv(int x)
{
return dg(x,mod-2);
}
void dao(int *f,int *g,int n)
{
for(int i=1;i<n;i++)
g[i-1]=(LL)i*f[i]%mod;
g[n-1]=0;
}
void jifen(int *f,int *g,int n)
{
for(int i=1;i<n;i++)
g[i]=(LL)f[i-1]*inv(i)%mod;
g[0]=0;
}
void init(int n)
{
limit=1,l=0;
while(limit<(n<<1))
limit<<=1,l++;
for(int i=1;i<limit;i++)
r[i]=((r[i>>1]>>1)|((i&1)<<(l-1)));
}
void NTT(int *now,int limit,int op)
{
for(int i=0;i<limit;i++)
if(i<r[i]) swap(now[i],now[r[i]]);
for(int mid=1;mid<limit;mid<<=1)
{
int wn=dg(op==1?G:inv_G,(mod-1)/(mid<<1));
for(int j=0;j<limit;j+=(mid<<1))
{
int w=1;
for(int k=0;k<mid;k++,w=((LL)w*wn)%mod)
{
int x=now[j+k],y=(LL)w*now[j+k+mid]%mod;
now[j+k]=(x+y)%mod;
now[j+k+mid]=(x-y+mod)%mod;
}
}
}
}
void dft(int *f,int n,int limit)
{
NTT(f,limit,-1);
int INV=inv(limit);
for(int i=0;i<n;i++)
f[i]=(LL)f[i]*INV%mod;
}
void poly_inv(int *f,int *g,int n)
{
if(n==1)
{
g[0]=inv(f[0]);
return;
}
poly_inv(f,g,(n+1)>>1);
init(n);
memset(a,0,sizeof(a));
memset(b,0,sizeof(b));
for(int i=0;i<n;i++)
a[i]=f[i],b[i]=g[i];
NTT(a,limit,1),NTT(b,limit,1);
for(int i=0;i<limit;i++)
b[i]=(LL)b[i]*((2ll-(LL)a[i]*b[i]%mod+mod)%mod)%mod;
dft(b,n,limit);
for(int i=0;i<n;i++)
g[i]=b[i];
}
void poly_ln(int *f,int n)
{
dao(f,g1,n);
poly_inv(f,g2,n);
init(n);
NTT(g1,limit,1),NTT(g2,limit,1);
for(int i=0;i<limit;i++)
g1[i]=(LL)g1[i]*g2[i]%mod;
dft(g1,n,limit);
jifen(g1,g2,n);
}
void poly_exp(int *f,int *g,int n)
{
if(n==1)
{
g[0]=1;
return;
}
poly_exp(f,g,(n+1)>>1);
memset(g1,0,sizeof(g1));
memset(g2,0,sizeof(g2));
poly_ln(g,n);
init(n);
memset(a,0,sizeof(a));
memset(b,0,sizeof(b));
for(int i=0;i<n;i++)
a[i]=g[i],b[i]=((LL)(!i)-g2[i]+f[i]+mod)%mod;
NTT(a,limit,1),NTT(b,limit,1);
for(int i=0;i<limit;i++)
a[i]=(LL)a[i]*b[i]%mod;
dft(a,n,limit);
for(int i=0;i<n;i++)
g[i]=a[i];
}
void poly_sqrt(int *f,int *g,int n)
{
poly_ln(f,n);
int inv2=inv(2);
for(int i=0;i<n;i++)
g3[i]=(LL)g2[i]*inv2%mod;
memset(g1,0,sizeof(g1));
memset(g2,0,sizeof(g2));
poly_exp(g3,g,n);
}
int main()
{
scanf("%d",&n);
for(int i=0;i<n;i++)
scanf("%d",&f[i]);
poly_sqrt(f,g,n);
for(int i=0;i<n;i++)
printf("%d ",g[i]);
}
思路 2 2 2:
我们像求逆一样推式子。
G
2
(
x
)
=
F
(
x
)
(
m
o
d
  
x
n
)
G^2(x)=F(x)(\mod x^n)
G2(x)=F(x)(modxn)
假设我们已经搞定了
A
2
(
x
)
≡
F
(
x
)
(
m
o
d
  
x
⌈
n
2
⌉
)
A^2(x)≡F(x)(\mod x^{\lceil\frac{n}{2}\rceil})
A2(x)≡F(x)(modx⌈2n⌉)。
两式相减,有:
G
2
(
x
)
−
A
2
(
x
)
≡
0
(
m
o
d
  
x
⌈
n
2
⌉
)
G^2(x)-A^2(x)≡0(\mod x^{{\lceil\frac{n}{2}\rceil}})
G2(x)−A2(x)≡0(modx⌈2n⌉)
两边同时平方,有:
(
G
2
(
x
)
−
A
2
(
x
)
)
2
≡
0
(
m
o
d
  
x
n
)
\big(G^2(x)-A^2(x)\big)^2≡0(\mod x^{n})
(G2(x)−A2(x))2≡0(modxn)
两边同时加上
4
A
2
(
x
)
G
2
(
x
)
4A^2(x)G^2(x)
4A2(x)G2(x),有:
(
G
2
(
x
)
+
A
2
(
x
)
)
2
≡
4
A
2
(
x
)
G
2
(
x
)
(
m
o
d
  
x
n
)
\big(G^2(x)+A^2(x)\big)^2≡4A^2(x)G^2(x)(\mod x^{n})
(G2(x)+A2(x))2≡4A2(x)G2(x)(modxn)
两边同时开方,有:
G
2
(
x
)
+
A
2
(
x
)
≡
2
A
(
x
)
G
(
x
)
(
m
o
d
  
x
n
)
G^2(x)+A^2(x)≡2A(x)G(x)(\mod x^{n})
G2(x)+A2(x)≡2A(x)G(x)(modxn)
F ( x ) + A 2 ( x ) ≡ 2 A ( x ) G ( x ) ( m o d    x n ) F(x)+A^2(x)≡2A(x)G(x)(\mod x^{n}) F(x)+A2(x)≡2A(x)G(x)(modxn)
G ( x ) = A 2 ( x ) + F ( x ) 2 A ( x ) G(x)=\frac{A^2(x)+F(x)}{2A(x)} G(x)=2A(x)A2(x)+F(x)
套一个多项式求逆即可。
时间上更优秀些(常数更小)。
代码:
#include<cstdio>
#include<cstring>
#include<algorithm>
#define LL long long
const LL mod=998244353,G=3,inv_G=332748118;
using namespace std;
int a[1000010],b[1000010],f[1000010],g[1000010],h1[1000010],h2[1000010],h3[1000010];
int limit,n,l,r[1000010];
int dg(int x,int k)
{
if(!k) return 1;
int op=dg(x,k>>1);
if(k&1) return (LL)op*op%mod*x%mod; else return (LL)op*op%mod;
}
int inv(int x)
{
return dg(x,mod-2);
}
void init(int n)
{
limit=1,l=0;
while(limit<(n<<1))
limit<<=1,l++;
for(int i=1;i<limit;i++)
r[i]=((r[i>>1]>>1)|((i&1)<<(l-1)));
}
void NTT(int *now,int limit,int op)
{
for(int i=0;i<limit;i++)
if(i<r[i]) swap(now[i],now[r[i]]);
for(int mid=1;mid<limit;mid<<=1)
{
int wn=dg(op==1?G:inv_G,(mod-1)/(mid<<1));
for(int j=0;j<limit;j+=(mid<<1))
{
int w=1;
for(int k=0;k<mid;k++,w=((LL)w*wn)%mod)
{
int x=now[j+k],y=(LL)w*now[j+k+mid]%mod;
now[j+k]=(x+y)%mod;
now[j+k+mid]=(x-y+mod)%mod;
}
}
}
}
void dft(int *f,int n,int limit)
{
NTT(f,limit,-1);
int INV=inv(limit);
for(int i=0;i<n;i++)
f[i]=(LL)f[i]*INV%mod;
}
void poly_inv(int *f,int *g,int n)
{
if(n==1)
{
g[0]=inv(f[0]);
return;
}
poly_inv(f,g,(n+1)>>1);
init(n);
memset(a,0,sizeof(a));
memset(b,0,sizeof(b));
for(int i=0;i<n;i++)
a[i]=f[i],b[i]=g[i];
NTT(a,limit,1),NTT(b,limit,1);
for(int i=0;i<limit;i++)
b[i]=(LL)b[i]*((2ll-(LL)a[i]*b[i]%mod+mod)%mod)%mod;
dft(b,n,limit);
for(int i=0;i<n;i++)
g[i]=b[i];
}
void poly_sqrt(int *f,int *g,int n)
{
if(n==1)
{
g[0]=1;
return;
}
poly_sqrt(f,g,(n+1)>>1);
init(n);
memset(a,0,sizeof(a));
for(int i=0;i<n;i++)
a[i]=g[i];
NTT(a,limit,1);
for(int i=0;i<limit;i++)
a[i]=(LL)a[i]*a[i]%mod;
NTT(a,limit,-1);
int INV=inv(limit);
memset(h1,0,sizeof(h1));
memset(h2,0,sizeof(h2));
memset(h3,0,sizeof(h3));
for(int i=0;i<n;i++)
h1[i]=((LL)a[i]*INV%mod+f[i])%mod;
for(int i=0;i<n;i++)
h2[i]=2ll*g[i]%mod;
poly_inv(h2,h3,n);
NTT(h1,limit,1),NTT(h3,limit,1);
for(int i=0;i<limit;i++)
h1[i]=(LL)h1[i]*h3[i]%mod;
NTT(h1,limit,-1);
for(int i=0;i<n;i++)
g[i]=(LL)h1[i]*INV%mod;
}
int main()
{
scanf("%d",&n);
for(int i=0;i<n;i++)
scanf("%d",&f[i]);
poly_sqrt(f,g,n);
for(int i=0;i<n;i++)
printf("%d ",g[i]);
}