problem
猎人杀是一款风靡一时的游戏 “ “ “狼人杀 ” ” ”的民间版本,他的规则是这样的:
开始有 n n n 个猎人,第 i i i 个猎人有仇恨度 w i w_i wi,每个猎人有一个固定的技能:死亡后必须开一枪,且被射中的人也会死亡。
然而向谁开枪也是有讲究的,假设当前还活着的猎人有 [ i 1 , . . . , i m ] [i_1,...,i_m] [i1,...,im],那么有 w i k ∑ j = 1 m w i j \frac{w_{i_k}}{\sum_{j=1}^mw_{i_j}} ∑j=1mwijwik 的概率是向猎人 i k i_k ik 开枪。
一开始第一枪由你打响,目标的选择方法和猎人一样(即有 w i ∑ j = 1 n w j \frac{w_i}{\sum_{j=1}^nw_j} ∑j=1nwjwi 的概率射中第 i i i 个猎人)。由于开枪导致的连锁反应,所有猎人最终都会死亡,现在 1 1 1 号猎人想知道它是最后一个死的的概率。
答案对 998244353 998244353 998244353 取模。
数据范围: w i > 0 w_i>0 wi>0 且 1 ≤ ∑ i = 1 n w i ≤ 1 0 5 1\le \sum_{i=1}^nw_i\le10^5 1≤∑i=1nwi≤105。
solution
非常妙的一道题。
首先有一个结论,即已经死去的猎人也可以算进概率中,直到下一次打死一个还没死的人。
证明如下:
设每个人的仇恨度之和为 W 1 W_1 W1,已经死去的人的仇恨度之和为 W 2 W_2 W2,那么下一步射死 i i i( i i i 之前都没死)的概率是 P = w i W 1 − W 2 P=\frac{w_i}{W_1-W_2} P=W1−W2wi。
如果把已经死去的猎人也算进概率中,就有:P = W 2 W 1 P + w i W 1 ( W 1 − W 2 ) P = w i P = w i W 1 − W 2 \begin{aligned} P&=\frac{W_2}{W_1}P+\frac{w_i}{W_1}\\ (W_1-W_2)P&=w_i\\ P&=\frac{w_i}{W_1-W_2} \end{aligned} P(W1−W2)PP=W1W2P+W1wi=wi=W1−W2wi
所以这和原问题是等价的。
然后我们考虑容斥算答案。枚举集合 S S S,强制让 S S S 中的人比 1 1 1 后死。令 W = ∑ i = 0 n w i W=\sum_{i=0}^nw_i W=∑i=0nwi, T = ∑ i ∈ S w i T=\sum_{i\in S}w_i T=∑i∈Swi,有:
a n s = ( − 1 ) ∣ S ∣ ∑ i = 0 ∞ ( 1 − T + w 1 W ) i w 1 W ans=(-1)^{|S|}\sum_{i=0}^{\infty}\left(1-\frac{T+w_1}{W}\right)^i\frac{w_1}{W} ans=(−1)∣S∣i=0∑∞(1−WT+w1)iWw1
这个式子的意思是,用前 i i i 枪去打不在 S S S 集合也不是 1 1 1 的人,打完 i i i 枪之后直接打死 1 1 1 的概率。
然后把这个式子化简:
a n s = ( − 1 ) ∣ S ∣ 1 1 − ( 1 − T + w 1 W ) w 1 W = ( − 1 ) ∣ S ∣ w 1 T + w 1 \begin{aligned} ans&=(-1)^{|S|}\frac{1}{1-(1-\frac{T+w_1}{W})}\frac{w_1}{W}\\ &=(-1)^{|S|}\frac{w_1}{T+w_1} \end{aligned} ans=(−1)∣S∣1−(1−WT+w1)1Ww1=(−1)∣S∣T+w1w1
换句话说,我们现在需要枚举集合 S S S,然后计算出这个集合的大小和仇恨度之和。
直接枚举肯定是不行的,考虑用生成函数来优化这一过程。
第 i i i 个数的生成函数是 ( 1 − x w i ) (1-x^{w_i}) (1−xwi),系数为 − 1 -1 −1 的原因是每加进来一个数, ∣ S ∣ |S| ∣S∣ 加一, ( − 1 ) ∣ S ∣ (-1)^{|S|} (−1)∣S∣ 就会变号。
那么就是要求 ∏ i = 2 n ( 1 − x w i ) \prod_{i=2}^n(1-x^{w_i}) ∏i=2n(1−xwi),分治 n t t ntt ntt即可。
时间复杂度 O ( n log 2 n ) O(n\log^2n) O(nlog2n)。
code
#include<cstdio>
#include<vector>
#include<cstring>
#include<algorithm>
#define N 500005
#define P 998244353
using namespace std;
const int g=3;
typedef vector<int> poly;
int n,pos[N],a[N],S[N];
poly f[N];
int add(int x,int y) {return x+y>=P?x+y-P:x+y;}
int dec(int x,int y) {return x-y< 0?x-y+P:x-y;}
int mul(int x,int y) {return 1ll*x*y%P;}
int power(int a,int b,int ans=1){
for(;b;b>>=1,a=mul(a,a))
if(b&1) ans=mul(ans,a);
return ans;
}
int Log=23,*w[24];
void init_w(){
for(int i=1;i<=Log;++i)
w[i]=new int[1<<(i-1)];
int now=power(g,(P-1)/(1<<Log));
w[Log][0]=1;
for(int i=1;i<(1<<(Log-1));++i) w[Log][i]=mul(w[Log][i-1],now);
for(int i=Log-1;i;--i)
for(int j=0;j<(1<<(i-1));++j)
w[i][j]=w[i+1][j<<1];
}
void init_pos(int lim){
for(int i=0;i<lim;++i)
pos[i]=(pos[i>>1]>>1)|((i&1)*(lim>>1));
}
void NTT(poly &f,int lim,int type){
for(int i=0;i<lim;++i)
if(pos[i]>i) swap(f[i],f[pos[i]]);
for(int mid=1,l=1;mid<lim;mid<<=1,++l){
for(int i=0;i<lim;i+=(mid<<1)){
for(int j=0;j<mid;++j){
int p0=f[i+j],p1=mul(f[i+j+mid],w[l][j]);
f[i+j]=add(p0,p1),f[i+j+mid]=dec(p0,p1);
}
}
}
if(type==-1&&(reverse(f.begin()+1,f.begin()+lim),1)){
int inv=power(lim,P-2);
for(int i=0;i<lim;++i) f[i]=mul(f[i],inv);
}
}
poly operator*(poly A,poly B){
int lim=1,len=A.size()+B.size()-2;
while(lim<=len) lim<<=1;init_pos(lim);
A.resize(lim),NTT(A,lim,1);
B.resize(lim),NTT(B,lim,1);
for(int i=0;i<lim;++i) A[i]=mul(A[i],B[i]);
NTT(A,lim,-1),A.resize(len+1);
return A;
}
void build(int root,int l,int r){
if(l==r){
f[root].resize(a[l]+1);
f[root][0]=1,f[root][a[l]]=P-1;
return;
}
int mid=(l+r)>>1;
build(root<<1,l,mid),build(root<<1|1,mid+1,r);
f[root]=f[root<<1]*f[root<<1|1];
}
int main(){
init_w();
scanf("%d",&n);
for(int i=1;i<=n;++i){
scanf("%d",&a[i]),S[i]=S[i-1]+a[i];
}
build(1,2,n);
int ans=0,tot=S[n]-a[1];
for(int i=0;i<=tot;++i)
ans=add(ans,mul(f[1][i],mul(a[1],power(i+a[1],P-2))));
printf("%d",ans);
return 0;
}