题目大意
一开始有 n n n个猎人,第 i i i个猎人有仇恨度 w i w_i wi。每次可以开枪射杀一个活着的猎人。
假设活着的猎人为 i 1 , i 2 , … , i m i_1,i_2,\dots,i_m i1,i2,…,im,则第 i k i_k ik个猎人被射杀的概率是 w i k ∑ j = 1 m w i j \frac{w_{i_k}}{\sum\limits_{j=1}^mw_{i_j}} j=1∑mwijwik。
求 1 1 1号猎人最后一个被射杀的概率。输出答案对 998244353 998244353 998244353取模。
w i > 0 , 1 ≤ ∑ w i ≤ 1 0 5 w_i>0,1\leq \sum w_i\leq 10^5 wi>0,1≤∑wi≤105
题解
转化题意
在不断射杀的过程中,概率的分母会不断改变。所以,我们可以稍微转化一下题意。
令 s u m = ∑ i = 1 n w i sum=\sum\limits_{i=1}^nw_i sum=i=1∑nwi,题意转化为:第 i i i个人被射杀的概率为 w i s u m \dfrac{w_i}{sum} sumwi,已经被射杀的人仍能继续被射杀。如果打中一个活着的人,那么这个人就死去。
为什么呢?因为在每次射杀的时候,每个活人被射杀的概率都他的仇恨度除以所有活人仇恨度的和。这样相对于原来,多了一个射杀死人的过程,但因为活人最终总会被射杀,所以本质上是一样的。
容斥
直接求 1 1 1号最后被射杀的概率比较困难,我们考虑使用容斥。
设在 1 1 1号之后被射杀的人的子集为 S S S的概率为 p ( S ) p(S) p(S),则答案为
a n s = ∑ ( − 1 ) ∣ S ∣ p ( S ) ans=\sum(-1)^{|S|}p(S) ans=∑(−1)∣S∣p(S)
然后,我们考虑如何求 p ( S ) p(S) p(S)。
设 v ( S ) = ∑ i ∈ S w i v(S)=\sum\limits_{i\in S}w_i v(S)=i∈S∑wi。因为在 1 1 1号之后被射杀的人包含 S S S,所以就相当于射杀若干次,每次射杀除了 1 1 1号和集合 S S S之外的人,直到打中 1 1 1号。
p ( S ) = ∑ i = 0 + ∞ ( s u m − w 1 − v ( S ) s u m ) i ⋅ w 1 s u m p(S)=\sum\limits_{i=0}^{+\infty}(\dfrac{sum-w_1-v(S)}{sum})^i\cdot\dfrac{w_1}{sum} p(S)=i=0∑+∞(sumsum−w1−v(S))i⋅sumw1
接下来求 ∑ i = 0 + ∞ ( s u m − w 1 − v ( S ) s u m ) i \sum\limits_{i=0}^{+\infty}(\frac{sum-w_1-v(S)}{sum})^i i=0∑+∞(sumsum−w1−v(S))i。由等比数列求和公式可得
∑ i = 0 + ∞ ( s u m − w 1 − v ( S ) s u m ) i = 1 − ( s u m − w 1 − v ( S ) s u m ) + ∞ 1 − s u m − w 1 − v ( S ) s u m = 1 w 1 + v ( S ) s u m = s u m w 1 + v ( S ) \sum\limits_{i=0}^{+\infty}(\frac{sum-w_1-v(S)}{sum})^i=\dfrac{1-(\frac{sum-w_1-v(S)}{sum})^{+\infty}}{1-\frac{sum-w_1-v(S)}{sum}}=\dfrac{1}{\frac{w_1+v(S)}{sum}}=\dfrac{sum}{w_1+v(S)} i=0∑+∞(sumsum−w1−v(S))i=1−sumsum−w1−v(S)1−(sumsum−w1−v(S))+∞=sumw1+v(S)1=w1+v(S)sum
所以
p ( S ) = s u m w 1 + v ( S ) ⋅ w 1 s u m = w 1 w 1 + v ( S ) p(S)=\dfrac{sum}{w_1+v(S)}\cdot \dfrac{w_1}{sum}=\dfrac{w_1}{w_1+v(S)} p(S)=w1+v(S)sum⋅sumw1=w1+v(S)w1
那么
a n s = ∑ ( − 1 ) ∣ S ∣ w 1 w 1 + v ( S ) ans=\sum(-1)^{|S|}\dfrac{w_1}{w_1+v(S)} ans=∑(−1)∣S∣w1+v(S)w1
生成函数
依题意, 1 ≤ ∑ w i ≤ 1 0 5 1\leq \sum w_i\leq 10^5 1≤∑wi≤105,所以我们可以枚举 v ( S ) v(S) v(S)。
令
g ( i ) = ∑ v ( S ) = i ( − 1 ) ∣ S ∣ g(i)=\sum\limits_{v(S)=i}(-1)^{|S|} g(i)=v(S)=i∑(−1)∣S∣
那么
a n s = ∑ i = 0 s u m g ( i ) ⋅ w 1 w 1 + i ans=\sum\limits_{i=0}^{sum}g(i)\cdot\dfrac{w_1}{w_1+i} ans=i=0∑sumg(i)⋅w1+iw1
于是问题就转化为求 g ( i ) g(i) g(i)了。我们发现, g ( i ) g(i) g(i)其实就是 ∏ i = 2 n ( 1 − x w i ) \prod\limits_{i=2}^n(1-x^{w_i}) i=2∏n(1−xwi)的第 i i i次项的系数。
N T T NTT NTT
接下来,我们要用 N T T NTT NTT来求 ∏ i = 2 n ( 1 − x w i ) \prod\limits_{i=2}^n(1-x^{w_i}) i=2∏n(1−xwi)。
用分治,求 [ l , r ] [l,r] [l,r]的多项式时,先求出 [ l , m i d ] [l,mid] [l,mid]和 [ m i d + 1 , r ] [mid+1,r] [mid+1,r]的多项式,再将两个多项式相乘即可。
我们把求的过程看作 log n \log n logn层,每层的时间复杂度为 O ( s u m log s u m ) O(sum\log sum) O(sumlogsum),所以总时间复杂度为 O ( s u m log s u m log n ) O(sum\log sum\log n) O(sumlogsumlogn)。
总时间复杂度可以看作 O ( n log 2 n ) O(n\log^2 n) O(nlog2n)。
code
#include<bits/stdc++.h>
using namespace std;
long long ans=0,w[100005],f[500005],g[20][500005];
const long long G=3,mod=998244353;
long long mi(long long t,long long v){
if(!v) return 1;
long long re=mi(t,v/2);
re=re*re%mod;
if(v&1) re=re*t%mod;
return re;
}
void ch(long long *a1,int l){
for(int i=1,j=l/2;i<l-1;i++){
if(i<j) swap(a1[i],a1[j]);
int k=l/2;
while(j>=k){
j-=k;k>>=1;
}
j+=k;
}
}
void ntt(long long *a1,int l,int fl){
long long W,wn;
for(int i=2;i<=l;i<<=1){
if(fl==1) wn=mi(G,(mod-1)/i);
else wn=mi(G,mod-1-(mod-1)/i);
for(int j=0;j<l;j+=i){
W=1;
for(int k=j;k<j+i/2;k++,W=W*wn%mod){
long long t=a1[k],u=W*a1[k+i/2]%mod;
a1[k]=(t+u)%mod;
a1[k+i/2]=(t-u+mod)%mod;
}
}
}
if(fl==-1){
long long ny=mi(l,mod-2);
for(int i=0;i<l;i++) a1[i]=a1[i]*ny%mod;
}
}
int solve(int l,int r,long long *a1,int now){
if(l==r){
a1[0]=1;a1[w[l]]=mod-1;
return w[l];
}
int mid=l+r>>1,vt,len=1;
vt=solve(l,mid,a1,now+1)+solve(mid+1,r,g[now+1],now+1);
while(len<=vt) len<<=1;
ch(a1,len);ch(g[now+1],len);
ntt(a1,len,1);ntt(g[now+1],len,1);
for(int i=0;i<len;i++){
a1[i]=a1[i]*g[now+1][i]%mod;
g[now+1][i]=0;
}
ch(a1,len);
ntt(a1,len,-1);
return vt;
}
int main()
{
int n;
scanf("%d",&n);
for(int i=1;i<=n;i++){
scanf("%lld",&w[i]);
}
if(n==1){
printf("1");return 0;
}
int vt=solve(2,n,f,0);
for(int i=0;i<=vt;i++){
ans=(ans+f[i]*w[1]%mod*mi(w[1]+i,mod-2)%mod)%mod;
}
printf("%lld",ans);
return 0;
}