问题
给定一个 n − 1 n-1 n−1 次多项式 A ( x ) A(x) A(x) ,求一个 n − 1 n-1 n−1 次多项式 B ( x ) B(x) B(x),使得 B 2 ( x ) = A ( x ) B^2(x)=A(x) B2(x)=A(x)。
思路
考虑倍增。假设已经求出
n
−
1
n-1
n−1 次多项式
B
0
(
x
)
B_0(x)
B0(x) 使得
B
0
2
(
x
)
≡
A
(
x
)
(
mod
x
n
)
B_0^2(x)\equiv A(x)\ (\text{mod}\ x^n)
B02(x)≡A(x) (mod xn),现在要求
B
2
(
x
)
≡
A
(
x
)
(
mod
x
2
n
)
B^2(x)\equiv A(x)\ (\text{mod}\ x^{2n})
B2(x)≡A(x) (mod x2n)。上述两式相减,得到
B
(
x
)
−
B
0
(
x
)
≡
0
(
mod
x
n
)
B(x)-B_0(x)\equiv0\ (\text{mod}\ x^n)
B(x)−B0(x)≡0 (mod xn)平方后拆括号可以得到
B
2
(
x
)
+
B
0
2
(
x
)
+
2
B
(
x
)
B
0
(
x
)
≡
0
(
mod
x
2
n
)
B^2(x)+B_0^2(x)+2B(x)B_0(x)\equiv0\ (\text{mod}\ x^{2n})
B2(x)+B02(x)+2B(x)B0(x)≡0 (mod x2n)
A
(
x
)
+
B
0
2
(
x
)
≡
2
B
(
x
)
B
0
(
x
)
(
mod
x
2
n
)
A(x)+B_0^2(x)\equiv2B(x)B_0(x)\ (\text{mod}\ x^{2n})
A(x)+B02(x)≡2B(x)B0(x) (mod x2n)因此
B
(
x
)
≡
A
(
x
)
+
B
0
2
(
x
)
2
B
0
(
x
)
≡
A
B
0
−
1
(
x
)
+
B
0
(
x
)
2
(
mod
x
2
n
)
B(x)\equiv\frac{A(x)+B_0^2(x)}{2B_0(x)}\equiv\frac{AB_0^{-1}(x)+B_0(x)}{2}\ (\text{mod}\ x^{2n})
B(x)≡2B0(x)A(x)+B02(x)≡2AB0−1(x)+B0(x) (mod x2n)
事实上,如果令
f
[
B
(
x
)
]
=
B
2
(
x
)
−
A
(
x
)
f[B(x)]=B^2(x)-A(x)
f[B(x)]=B2(x)−A(x),那么这个结果也可以用牛顿迭代法得到。
通过这个式子,就可以
O
(
n
log
2
n
)
O(n\log_2n)
O(nlog2n) 解决问题。
代码
*注:为了清晰,少了优化,可能超时。
#include<iostream>
#include<cstdio>
#include<cstring>
#define ll long long
using namespace std;
ll mod=998244353,G[2]={332748118,3};
int rev[400001];
ll qp(ll x,int y){
if(y==0) return 1ll;
if(y==1) return x;
ll res=qp(x,y>>1);
(res*=res)%=mod;
if(y&1) (res*=x)%=mod;
return res;
}
ll sub(ll a,ll b){
if(a+b>mod) return a+b-mod;
if(a+b<0) return a+b+mod;
return a+b;
}
void NTT(ll *a,int n,int flag){
ll w_n,w,x,y;
for(int i=1;i<n;i+=1) if(i<rev[i]) swap(a[i],a[rev[i]]);
for(int i=1;i<n;i<<=1){
w_n=qp(G[flag],(mod-1)/(i<<1));
for(int j=0;j<n;j+=(i<<1)){
w=1ll;
for(int k=j;k<i+j;k+=1){
x=a[k]; y=(a[k+i]*w)%mod;
a[k]=sub(x,y);
a[k+i]=sub(x,-y);
(w*=w_n)%=mod;
}
}
}
if(!flag){
ll invn=qp(n,mod-2);
for(int i=0;i<n;i+=1) (a[i]*=invn)%=mod;
}
return;
}
ll gh[400001],fh[400001];
void invF(ll *a,ll *b,int n){
memset(gh,0,sizeof(gh));
memset(fh,0,sizeof(fh));
int m=1;
b[0]=qp(a[0],mod-2);
for(int i=1;i<n;i<<=1){
memset(gh,0,sizeof(gh));
for(int j=1;j<(i<<2);j+=1) rev[j]=((rev[j>>1]>>1)|((j&1)<<m));
for(int j=0;j<i;j+=1) gh[j]=b[j];
for(int j=0;j<(i<<1);j+=1) fh[j]=a[j];
NTT(gh,i<<2,1); NTT(fh,i<<2,1);
for(int j=0;j<(i<<2);j+=1) gh[j]=(((gh[j]*gh[j])%mod)*fh[j])%mod;
NTT(gh,i<<2,0);
for(int j=i;j<(i<<1);j+=1) b[j]=sub(0,-gh[j]);
m+=1;
}
return;
}
ll h[400001],invh[400001];
void sqrtF(ll *a,ll *b,int n){
int m=1;
ll x=qp(2ll,mod-2);
b[0]=1ll;
for(int i=1;i<n;i<<=1){ //倍增
for(int j=0;j<i;j+=1) h[j]=b[j];
for(int j=0;j<(i<<1);j+=1) b[j]=a[j];
invF(h,invh,i<<1); //求逆元
for(int j=1;j<(i<<2);j+=1) rev[j]=((rev[j>>1]>>1)|((j&1)<<m));
NTT(b,i<<2,1); NTT(invh,i<<2,1);
for(int j=0;j<(i<<2);j+=1) (b[j]*=invh[j])%=mod; //乘起来
NTT(b,i<<2,0);
for(int j=0;j<i;j+=1) b[j]=sub(b[j],h[j]); //加上去
for(int j=0;j<(i<<1);j+=1) (b[j]*=x)%=mod; //除以2
m+=1;
}
return;
}
ll f[400001],sqrtf[400001];
int main(){
int n;
scanf("%d",&n);
for(int i=0;i<n;i+=1) scanf("%lld",&f[i]);
sqrtF(f,sqrtf,n);
for(int i=0;i<n;i+=1) printf("%lld\n",sqrtf[i]);
return 0;
}
谢谢观看,记得点赞