problem
Alice 和 Bob 又在玩游戏。
对于一次游戏,首先 Alice 获得一个长度为 n n n 的序列 a a a,Bob 获得一个长度为 m m m 的序列 b b b。之后他们各从自己的序列里随机取出一个数,分别设为 a x , b y a_x, b_y ax,by,定义这次游戏的 k k k 次价值为 ( a x + b y ) k (a_x + b_y)^k (ax+by)k。
由于他们发现这个游戏实在是太无聊了,所以想让你帮忙计算对于 i = 1 , 2 , ⋯   , t i = 1, 2, \cdots, t i=1,2,⋯,t,一次游戏 i i i 次价值的期望是多少。
由于答案可能很大,只需要求出模 998244353 998244353 998244353 下的结果即可。
数据范围: 1 ≤ n , m ≤ 1 0 5 1≤n,m≤10^5 1≤n,m≤105, 0 ≤ a i , b i < 998244353 0≤a_i,b_i <998244353 0≤ai,bi<998244353, 1 ≤ t ≤ 1 0 5 1≤t≤10^5 1≤t≤105。
solution
不难发现,对于 k ∈ [ 1 , t ] k\in[1,t] k∈[1,t],有:
a n s k = ∑ i = 1 n ∑ j = 1 m ( a i + b j ) k n m ans_k=\frac{\sum_{i=1}^n\sum_{j=1}^m(a_i+b_j)^k}{nm} ansk=nm∑i=1n∑j=1m(ai+bj)k
1 n m \frac 1 {nm} nm1 我们先不管,最后再处理。
把 ( a i + b j ) k (a_i+b_j)^k (ai+bj)k 用二项式定理展开,得到(这里的 a n s k ans_k ansk 表示的只是暂时除去了 1 n m \frac 1 {nm} nm1 的答案):
a n s k = ∑ i = 1 n ∑ j = 1 m ∑ h = 1 k C k h a i    h b j    k − h = ∑ h = 1 k C k h ∑ i = 1 n a i    h ∑ j = 1 m b j    k − h \begin{aligned} ans_k&=\sum_{i=1}^n\sum_{j=1}^m\sum_{h=1}^kC_{k}^ha_i^{\;h}b_{j}^{\;k-h}\\ &=\sum_{h=1}^kC_{k}^h\sum_{i=1}^na_i^{\;h}\sum_{j=1}^mb_{j}^{\;k-h} \end{aligned} ansk=i=1∑nj=1∑mh=1∑kCkhaihbjk−h=h=1∑kCkhi=1∑naihj=1∑mbjk−h
然后把组合数拆开,得到:
a n s k = k ! ∑ h = 1 k ∑ i = 1 n a i    h h ! ∑ j = 1 m b j    k − h ( k − h ) ! ans_k=k!\sum_{h=1}^k\sum_{i=1}^n\frac{a_i^{\;h}}{h!}\sum_{j=1}^m\frac{b_{j}^{\;k-h}}{(k-h)!} ansk=k!h=1∑ki=1∑nh!aihj=1∑m(k−h)!bjk−h
这样有点不太直观,我们令 A ( x ) = ∑ i = 1 n a i    x x ! A(x)=\sum_{i=1}^n\frac{a_i^{\;x}}{x!} A(x)=∑i=1nx!aix, B ( x ) = ∑ i = 1 m b i    x x ! B(x)=\sum_{i=1}^m\frac{b_i^{\;x}}{x!} B(x)=∑i=1mx!bix,那么有:
a n s k = k ! ∑ h = 1 k A ( h ) B ( k − h ) ans_k=k!\sum_{h=1}^kA(h)B(k-h) ansk=k!h=1∑kA(h)B(k−h)
发现这就是个卷积的形式,那么现在的问题主要就是怎么把 A ( x ) A(x) A(x) 和 B ( x ) B(x) B(x) 求出来了。
这很像我之前写的一道题:小 L 的计算题,那么:
A ( x ) = n − x ( ln    ( ∏ j = 1 n ( 1 − a j x ) ) ) ′ x ! A(x)=\frac{n-x(\ln\;(\prod_{j=1}^n(1-a_jx)))'}{x!} A(x)=x!n−x(ln(∏j=1n(1−ajx)))′
B ( x ) B(x) B(x) 类似。具体推导过程就去那篇博客上看吧,这里就不赘述了。
所以说,我们可以多项式 n t t ntt ntt 和多项式求 ln \ln ln 先处理出 A ( x ) A(x) A(x),和 B ( x ) B(x) B(x),然后直接把 A , B A,B A,B 卷起来就可以了。
注意最后要乘个 n m nm nm 的逆元。
时间复杂度 O ( n log 2 n ) O(n\log^2n) O(nlog2n)。
code
#include<cstdio>
#include<vector>
#include<cstring>
#include<algorithm>
#define N 500005
#define P 998244353
using namespace std;
const int g=3;
typedef vector<int> poly;
int n,m,t,pos[N],a[N],b[N],fac[N],ifac[N],inv[N];
int add(int x,int y) {return x+y>=P?x+y-P:x+y;}
int dec(int x,int y) {return x-y< 0?x-y+P:x-y;}
int mul(int x,int y) {return 1ll*x*y%P;}
int power(int a,int b,int ans=1){
for(;b;b>>=1,a=mul(a,a))
if(b&1) ans=mul(ans,a);
return ans;
}
int *w[21],C=20;
void init_w(){
for(int i=1;i<=C;++i)
w[i]=new int[1<<(i-1)];
int now=power(g,(P-1)/(1<<C));
w[C][0]=1;
for(int i=1;i<(1<<(C-1));++i) w[C][i]=mul(w[C][i-1],now);
for(int i=C-1;i;--i)
for(int j=0;j<(1<<(i-1));++j)
w[i][j]=w[i+1][j<<1];
}
void init_fac(){
fac[0]=fac[1]=inv[1]=1;
for(int i=2;i<N;++i) fac[i]=mul(fac[i-1],i);
ifac[N-1]=power(fac[N-1],P-2);
for(int i=N-2;~i;--i) ifac[i]=mul(ifac[i+1],i+1);
for(int i=2;i<N;++i) inv[i]=mul(P-P/i,inv[P%i]);
}
void init_pos(int lim){
for(int i=0;i<lim;++i)
pos[i]=(pos[i>>1]>>1)|((i&1)*(lim>>1));
}
void NTT(poly &f,int lim,int type){
for(int i=0;i<lim;++i)
if(pos[i]>i) swap(f[i],f[pos[i]]);
for(int mid=1,l=1;mid<lim;mid<<=1,++l){
for(int i=0;i<lim;i+=(mid<<1)){
for(int j=0;j<mid;++j){
int p0=f[i+j],p1=mul(f[i+j+mid],w[l][j]);
f[i+j]=add(p0,p1),f[i+j+mid]=dec(p0,p1);
}
}
}
if(type==-1&&(reverse(f.begin()+1,f.begin()+lim),1)){
int inv=power(lim,P-2);
for(int i=0;i<lim;++i) f[i]=mul(f[i],inv);
}
}
poly operator*(poly A,poly B){
int lim=1,len=A.size()+B.size()-2;
while(lim<=len) lim<<=1;init_pos(lim);
A.resize(lim),NTT(A,lim,1);
B.resize(lim),NTT(B,lim,1);
for(int i=0;i<lim;++i) A[i]=mul(A[i],B[i]);
NTT(A,lim,-1),A.resize(len+1);
return A;
}
poly Inv(poly A,int len){
poly C,B(1,power(A[0],P-2));
for(int lim=4;lim<(len<<2);lim<<=1){
init_pos(lim);
C=A,C.resize(lim>>1);
C.resize(lim),NTT(C,lim,1);
B.resize(lim),NTT(B,lim,1);
for(int i=0;i<lim;++i) B[i]=mul(B[i],dec(2,mul(B[i],C[i])));
NTT(B,lim,-1),B.resize(lim>>1);
}
B.resize(len);return B;
}
poly Deriv(poly A){
for(int i=0;i<A.size()-1;++i) A[i]=mul(A[i+1],i+1);
A.pop_back();return A;
}
poly Integ(poly A){
A.push_back(0);
for(int i=A.size()-1;i;--i) A[i]=mul(A[i-1],inv[i]);
A[0]=0;return A;
}
poly Ln(poly A,int len){
A=Integ(Deriv(A)*Inv(A,len)),A.resize(len);
return A;
}
poly f[N];
void build(int root,int l,int r,int *a){
if(l==r){
f[root].clear();
f[root].push_back(1),f[root].push_back(P-a[l]);
return;
}
int mid=(l+r)>>1;
build(root<<1,l,mid,a),build(root<<1|1,mid+1,r,a);
f[root]=f[root<<1]*f[root<<1|1];
}
poly solve(int n,int *a){
build(1,1,n,a);
poly now=Deriv(Ln(f[1],t+1));
now.push_back(0);
for(int i=now.size()-1;i;--i) now[i]=P-now[i-1];
now[0]=n;
for(int i=0;i<now.size();++i) now[i]=mul(now[i],ifac[i]);
return now;
}
int main(){
scanf("%d%d",&n,&m),init_w(),init_fac();
for(int i=1;i<=n;++i) scanf("%d",&a[i]);
for(int i=1;i<=m;++i) scanf("%d",&b[i]);
scanf("%d",&t);
poly F=solve(n,a),G=solve(m,b);
F=F*G;
int temp=mul(power(n,P-2),power(m,P-2));
for(int i=1;i<=t;++i) printf("%d\n",mul(temp,mul(F[i],fac[i])));
return 0;
}