Description
给出n个数
a
1
.
.
.
a
n
a_1...a_n
a1...an
求它们的
1
1
1~
n
n
n次方和
n
≤
200000
,
a
≤
1
0
9
n\leq 200000,a\leq 10^9
n≤200000,a≤109
20组数据,n总和不超过400000
Solution
一种做法是利用牛顿恒等式
考虑多项式
F
(
x
)
=
a
n
x
n
+
a
n
−
1
x
n
−
1
+
⋯
+
a
1
x
+
a
0
F(x)=a_nx^n+a_{n-1}x^{n-1}+\cdots+a_1x+a_0
F(x)=anxn+an−1xn−1+⋯+a1x+a0
令其所有根为
x
1
,
x
2
,
.
.
.
,
x
n
x_1,x_2,...,x_n
x1,x2,...,xn
设
b
i
=
a
n
−
i
b_i=a_{n-i}
bi=an−i(反位),
S
k
=
∑
i
=
1
n
x
i
k
S_k=\sum\limits_{i=1}^{n}x_i^k
Sk=i=1∑nxik
恒有
∑
i
=
1
k
S
i
b
k
−
i
+
k
×
b
k
=
0
,
k
∈
N
∗
\sum\limits_{i=1}^{k}S_ib_{k-i}+k\times b_k=0,k\in \N^*
i=1∑kSibk−i+k×bk=0,k∈N∗
证明就不说了,网上有不少的论文有讲。
如果我们将
S
,
b
,
k
∗
b
k
S,b,k*b_k
S,b,k∗bk均用一般生成函数表示
F
(
x
)
=
∑
i
>
0
S
i
x
i
,
B
(
x
)
=
∑
i
≥
0
b
i
x
i
,
G
(
x
)
=
∑
i
≥
0
i
b
i
x
i
F(x)=\sum\limits_{i>0}S_ix^i,B(x)=\sum\limits_{i\geq 0}b_ix^i,G(x)=\sum\limits_{i\geq 0}ib_ix^i
F(x)=i>0∑Sixi,B(x)=i≥0∑bixi,G(x)=i≥0∑ibixi
那么有
F
(
x
)
B
(
x
)
+
G
(
x
)
=
0
F(x)B(x)+G(x)=0
F(x)B(x)+G(x)=0
F
(
x
)
=
−
G
(
x
)
B
(
x
)
F(x)={-G(x)\over B(x)}
F(x)=B(x)−G(x)
构造多项式满足根为给定值是容易的,即 ( x − a 1 ) ( x − a 2 ) . . . ( x − a n ) (x-a_1)(x-a_2)...(x-a_n) (x−a1)(x−a2)...(x−an),分治NTT求出系数,然后多项式求逆即可。
时间复杂度 O ( n log 2 n ) O(n\log^2 n) O(nlog2n)
此外我自己还想了一种不需要牛顿恒等式的做法
设
S
i
=
∑
j
=
1
n
a
j
i
S_i=\sum\limits_{j=1}^{n}a_j^i
Si=j=1∑naji
考虑S的一般生成函数
方便起见,我们给S(x)的x^0项也加上,最后的值不管他
S ( x ) = ∑ i ≥ 0 ∑ j = 1 n a j i x i S(x)=\sum\limits_{i\geq 0}\sum\limits_{j=1}^{n}a_j^ix^i S(x)=i≥0∑j=1∑najixi
=
∑
j
=
1
n
∑
i
≥
0
a
j
i
x
i
=\sum\limits_{j=1}^{n}\sum\limits_{i\geq 0}a_j^ix^i
=j=1∑ni≥0∑ajixi
=
∑
j
=
1
n
1
1
−
a
j
x
=\sum\limits_{j=1}^{n}{1\over 1-a_jx}
=j=1∑n1−ajx1
设
P
(
x
)
=
∏
j
=
1
n
(
1
−
a
j
x
)
P(x)=\prod\limits_{j=1}^{n}(1-a_jx)
P(x)=j=1∏n(1−ajx)
那么
S
(
x
)
=
∑
j
=
1
n
∑
k
̸
=
j
(
1
−
a
k
x
)
P
(
x
)
S(x)={\sum\limits_{j=1}^{n}\sum\limits_{k\not=j}(1-a_kx)\over P(x)}
S(x)=P(x)j=1∑nk̸=j∑(1−akx)
P
(
x
)
P(x)
P(x)可以分治NTT计算,但是上面挖掉某一项乘积的和该怎么计算呢
这就有一个很Tricky的方法,之前有一道题[JZOJ5998]【WC2019模拟2019.1.14】操作采用了类似的做法
分治NTT时,对于分治区间分别记录总的乘积的挖掉一项的乘积的和,那么左右相乘的时候一边的总乘积乘上另一边挖掉一项的乘积和就是总的挖掉一项的乘积和。
这样复杂度也是 O ( n log 2 n ) O(n\log ^2n) O(nlog2n)的。
Code
#include <bits/stdc++.h>
#define fo(i,a,b) for(int i=a;i<=b;++i)
#define fod(i,a,b) for(int i=a;i>=b;--i)
#define LL long long
#define N 200005
#define M 524288
#define mo 998244353
using namespace std;
LL a[M+1],b[M+1],c[M+1],u1[M+1],u2[M+1],u3[M+1],wi[M+1],wg[M+1],ny;
int bit[M+1],n,t,a1[M],l2[M+1],cf[21],st[N],le[N];
LL ksm(LL k,LL n)
{
LL s=1;
for(;n;n>>=1,k=k*k%mo) if(n&1) s=s*k%mo;
return s;
}
void prp(int num)
{
fo(i,0,num)
{
wi[i]=wg[i*(M/num)];
bit[i]=(bit[i>>1]>>1)|((i&1)<<(l2[num]-1));
}
ny=ksm(num,mo-2);
}
void NTT(LL *a,bool pd,int num)
{
fo(i,0,num-1) if(bit[i]>i) swap(a[i],a[bit[i]]);
for(int half=1,m=2,lim=num>>1;m<=num;half=m,m<<=1,lim>>=1)
{
LL wn=(!pd)?wi[lim]:wi[num-lim],v;
for(int j=0;j<num;j+=m)
{
LL w=1,v;
for(int i=0;i<half;++i,w=w*wn%mo)
{
v=w*a[i+j+half]%mo;
a[i+j+half]=(a[i+j]-v+mo)%mo;
a[i+j]=(a[i+j]+v)%mo;
}
}
}
if(pd) fo(i,0,num-1) a[i]=a[i]*ny%mo;
}
void doit(int l,int r)
{
if(l==r) return;
int mid=(l+r)>>1;
doit(l,mid),doit(mid+1,r);
int num=cf[l2[le[l]+le[mid+1]-1]];
fo(i,0,num-1) u1[i]=u2[i]=0;
fo(i,0,le[l]-1) u1[i]=a[st[l]+i];
fo(i,0,le[mid+1]-1) u2[i]=a[st[mid+1]+i];
prp(num);
NTT(u1,0,num),NTT(u2,0,num);
fo(i,0,num-1) u1[i]=u1[i]*u2[i]%mo;
NTT(u1,1,num);
le[l]+=le[mid+1]-1;
fo(i,0,le[l]-1) a[st[l]+i]=u1[i];
}
void getinv(int n,LL *a,LL *b)
{
fo(i,0,n-1) b[i]=0;
b[0]=ksm(a[0],mo-2);
for(int m=1,t=2,num=4;m<n;m=t,t=num,num<<=1)
{
prp(num);
fo(i,0,num-1) u2[i]=u1[i]=0;
fo(i,0,m-1) u2[i]=b[i];
fo(i,0,t-1) u1[i]=a[i];
NTT(u1,0,num),NTT(u2,0,num);
fo(i,0,num-1) u1[i]=u1[i]*u2[i]%mo*u2[i]%mo;
NTT(u1,1,num);
fo(i,0,t-1) b[i]=((LL)2*b[i]-u1[i]+mo)%mo;
}
}
int main()
{
cf[0]=1;
fo(i,1,19) l2[cf[i]=cf[i-1]<<1]=i;
fod(i,M,2) if(!l2[i]) l2[i]=l2[i+1];
wg[0]=1,wg[1]=ksm(3,(mo-1)/M);
fo(i,2,M) wg[i]=wg[i-1]*wg[1]%mo;
cin>>t;
while(t--)
{
memset(c,0,sizeof(c));
memset(b,0,sizeof(b));
scanf("%d",&n);
int len=-1;
fo(i,1,n)
{
int x;
scanf("%d",&x);
a[st[i]=++len]=-x;
a[++len]=1;
le[i]=2;
}
doit(1,n);
fo(i,0,le[1]-1) b[n-i]=a[i];
getinv(n+1,b,c);
fo(i,0,n) b[i]=-b[i]*(LL)i%mo;
int num=cf[l2[n]+1];
prp(num);
NTT(b,0,num),NTT(c,0,num);
fo(i,0,num-1) b[i]=b[i]*c[i];
NTT(b,1,num);
LL ans=0;
fo(i,1,n) ans^=(b[i]+mo)%mo;
printf("%lld\n",ans);
}
}