problem
现有一个长度为 $n\ (1\le n\le 2\times10^5)$ 的非负整数数组 $\{a_i\}$。小 L 定义了一种神奇变换：
$$f_k=\sum_{i=1}^{n}a_i^{\,k}\pmod{998244353}$$
小 L 计划用变换生成的序列 $f$ 做一些有趣的事情，但是他并不擅长算乘法，所以来找你帮忙，希望你能帮他尽快计算出 $f_1\sim f_n$。（下面的代码支持多组数据，每组输出 $f_1\oplus f_2\oplus\cdots\oplus f_n$，即它们的异或和。）
solution
考虑先构造 $f$ 的 OGF（约定 $0^0=1$，故 $f_0=n$）：
$$\begin{aligned} f(x)&=\sum_{i=0}^{\infty}\Big(\sum_{j=1}^{n}a_j^{\,i}\Big)x^i\\ &=\sum_{j=1}^{n}\sum_{i=0}^{\infty}(a_jx)^i\\ &=\sum_{j=1}^{n}\frac{1}{1-a_jx} \end{aligned}$$
这个式子好像也不好求,继续化简:
$$\begin{aligned} f(x)&=\sum_{j=1}^{n}\Big(1-\frac{-a_jx}{1-a_jx}\Big)\\ &=n-x\sum_{j=1}^{n}\frac{-a_j}{1-a_jx} \end{aligned}$$
由于 $\dfrac{-a_j}{1-a_jx}=\big(\ln(1-a_jx)\big)'$，直接代入，得：
$$f(x)=n-x\sum_{j=1}^{n}\big(\ln(1-a_jx)\big)'$$
又由于求导是线性的，即 $(u+v)'=u'+v'$，所以求和号可以移到求导符号内部：
$$f(x)=n-x\Big(\sum_{j=1}^{n}\ln(1-a_jx)\Big)'$$
再利用 $\sum\ln=\ln\prod$ 把 $\ln$ 提到外面，得到：
$$f(x)=n-x\Big(\ln\prod_{j=1}^{n}(1-a_jx)\Big)'$$
比较系数可知，对 $k\ge1$ 有 $f_k=-[x^{k-1}]\Big(\ln\prod_{j=1}^{n}(1-a_jx)\Big)'$。于是先用分治 NTT 求出 $\prod_{j=1}^{n}(1-a_jx)$（$O(n\log^2 n)$），再对其做多项式 $\ln$（$O(n\log n)$）即可。
code
#include<algorithm>
#include<cstdio>
#include<cstring>
#include<utility>
#include<vector>
// Capacity of the global tables below: node ids of the divide-and-conquer tree in
// solve() (heap indexing, roughly up to 4n), pos[] (NTT length) and inv[] (integration).
#define N 1000005
// NTT-friendly prime modulus: 998244353 = 119 * 2^23 + 1.
#define P 998244353
using namespace std;
// Primitive root of P; prework() derives the roots of unity from it.
const int g=3;
// Polynomial as a coefficient vector: index i holds the coefficient of x^i.
typedef vector<int> poly;
// n: length of the current test's array; T: number of test cases; a[1..n]: the array;
// pos: bit-reversal permutation written by init(); inv[i]: inverse of i mod P (filled in main()).
int n,T,a[N],pos[N],inv[N];
// f[id]: product polynomial stored at node id of the recursion tree in solve() (root id = 1).
poly f[N];
// Modular helpers for the prime P; operands are expected to lie in [0, P).

// (x + y) mod P.
int add(int x,int y) {
    int s = x + y;
    if (s >= P) s -= P;
    return s;
}

// (x - y) mod P.
int dec(int x,int y) {
    int d = x - y;
    if (d < 0) d += P;
    return d;
}

// (x * y) mod P, using a 64-bit intermediate product.
int mul(int x,int y) {
    return (int)((long long)x * y % P);
}

// Returns ans * a^b mod P by binary exponentiation (ans defaults to 1).
int power(int a,int b,int ans=1){
    int base = a;
    while (b) {
        if (b & 1) ans = mul(ans, base);
        base = mul(base, base);
        b >>= 1;
    }
    return ans;
}
// Root-of-unity tables for NTT: after prework(), w[l][j] = omega_l^j for 0 <= j < 2^(l-1),
// where omega_l = g^((P-1)/2^l) is a primitive 2^l-th root of unity mod P.
// Levels 1..C are built, so transform lengths up to 2^C are supported.
int *w[22],C=21;
void prework(){
    for (int lvl = 1; lvl <= C; ++lvl) w[lvl] = new int[1 << (lvl - 1)];
    // Finest level: successive powers of a primitive 2^C-th root of unity.
    const int step = power(g, (P - 1) / (1 << C));
    const int top = 1 << (C - 1);
    int cur = 1;
    for (int j = 0; j < top; ++j) {
        w[C][j] = cur;
        cur = mul(cur, step);
    }
    // Each coarser level is every other entry of the next finer level.
    for (int lvl = C - 1; lvl >= 1; --lvl) {
        const int cnt = 1 << (lvl - 1);
        for (int j = 0; j < cnt; ++j) w[lvl][j] = w[lvl + 1][2 * j];
    }
}
// Fills pos[0..lim) with the bit-reversal permutation for a length-lim transform
// (lim a power of two). pos[0] stays 0.
void init(int lim){
    const int high = lim >> 1;
    for (int i = 0; i < lim; ++i) {
        int r = pos[i >> 1] >> 1;
        if (i & 1) r |= high;
        pos[i] = r;
    }
}
// In-place iterative radix-2 NTT on the first lim entries of f.
// Requires: init(lim) was called, lim is a power of two with lim <= 2^C, f.size() >= lim.
// type == 1: forward transform. type == -1: inverse transform, done as a forward
// transform, then reversing f[1..lim), then scaling every entry by 1/lim.
void NTT(poly &f,int lim,int type){
    for(int i=0;i<lim;++i)
        if(pos[i]>i) swap(f[i],f[pos[i]]);
    // Level l combines blocks of size mid = 2^(l-1) with twiddles w[l][j].
    for(int mid=1,l=1;mid<lim;mid<<=1,++l){
        for(int i=0;i<lim;i+=(mid<<1)){
            for(int j=0;j<mid;++j){
                int p0=f[i+j],p1=mul(f[i+j+mid],w[l][j]);
                f[i+j]=add(p0,p1),f[i+j+mid]=dec(p0,p1);
            }
        }
    }
    if(type==-1){
        // Reversing indices 1..lim-1 turns the forward transform into the inverse
        // one, up to the factor lim removed below.
        reverse(f.begin()+1,f.begin()+lim);
        const int lim_inv=power(lim,P-2);  // distinct name: does not shadow the global inv[] table
        for(int i=0;i<lim;++i) f[i]=mul(f[i],lim_inv);
    }
}
// Returns the product A*B with exactly A.size()+B.size()-1 coefficients.
// An empty operand represents the zero polynomial and yields an empty result.
// Overwrites the shared pos[] table via init().
poly operator*(poly A,poly B){
    // Without this guard, A.size()+B.size()-2 underflows in size_t for empty operands.
    if(A.empty()||B.empty()) return poly();
    int len=(int)(A.size()+B.size())-2,lim=1;  // len = degree of the product
    while(lim<=len) lim<<=1;
    init(lim);
    A.resize(lim),NTT(A,lim,1);
    B.resize(lim),NTT(B,lim,1);
    for(int i=0;i<lim;++i) A[i]=mul(A[i],B[i]);
    NTT(A,lim,-1),A.resize(len+1);
    return A;
}
// Returns B with A*B == 1 (mod x^len), by Newton iteration B <- B*(2 - A*B).
// Requires A[0] != 0 mod P. When a round starts with transform length lim, the
// current approximation is correct mod x^(lim/4); after the round it is correct mod x^(lim/2).
poly Inv(poly A,int len){
    poly res(1,power(A[0],P-2));
    int lim=4;
    while(lim<(len<<2)){
        init(lim);
        const int half=lim>>1;
        // A mod x^half, zero-padded to the transform length.
        poly head(lim,0);
        for(int i=0;i<half&&i<(int)A.size();++i) head[i]=A[i];
        NTT(head,lim,1);
        res.resize(lim);
        NTT(res,lim,1);
        for(int i=0;i<lim;++i) res[i]=mul(res[i],dec(2,mul(res[i],head[i])));
        NTT(res,lim,-1);
        res.resize(half);
        lim<<=1;
    }
    res.resize(len);
    return res;
}
// Formal derivative: returns A' with A.size()-1 coefficients (empty for an empty or
// constant A).
poly Deriv(poly A){
    // For empty A, A.size()-1 would wrap around and pop_back() would be undefined.
    if(A.empty()) return A;
    for(size_t i=0;i+1<A.size();++i) A[i]=mul(A[i+1],(int)(i+1));
    A.pop_back();
    return A;
}
// Formal integral with zero constant term: result[0] = 0 and result[i] = A[i-1] / i
// for 1 <= i <= A.size(). Reads the global inverse table inv[], so A.size() must be
// a valid index into it.
poly Integ(poly A){
    poly res(A.size()+1,0);
    for(size_t i=1;i<res.size();++i) res[i]=mul(A[i-1],inv[i]);
    return res;
}
// Returns ln(A) mod x^len, computed as the integral of A' * A^{-1}.
// Assumes A[0] == 1 (otherwise the result is ln(A / A[0])).
poly Ln(poly A,int len){
    if(len<=0) return poly();
    A.resize(len);                   // ln(A) mod x^len depends only on A mod x^len
    poly q=Deriv(A)*Inv(A,len);      // inverse precision follows len, not the global n
    q.resize(len-1);                 // only len-1 terms survive integration; bounds inv[] reads
    A=Integ(q);
    A.resize(len);
    return A;
}
void solve(int root,int l,int r){
if(l==r){
f[root].clear();
f[root].push_back(1),f[root].push_back(P-a[l]);
return;
}
int mid=(l+r)>>1;
solve(root<<1,l,mid),solve(root<<1|1,mid+1,r);
f[root]=f[root<<1]*f[root<<1|1];
}
int main(){
prework();
scanf("%d",&T);
inv[0]=inv[1]=1;
for(int i=2;i<N;++i) inv[i]=mul(P-P/i,inv[P%i]);
while(T--){
scanf("%d",&n);
for(int i=1;i<=n;++i) scanf("%d",&a[i]),a[i]%=P;
solve(1,1,n);
poly now=Deriv(Ln(f[1],n+1));
now.push_back(0);
for(int i=now.size()-1;i;--i) now[i]=dec(P,now[i-1]);
now[0]=n;
int ans=0;
for(int i=1;i<=n;++i) ans^=now[i];
printf("%d\n",ans);
}
return 0;
}