测试地址:猎人杀
做法:本题需要用到容斥+级数+分治NTT。
要求
1
1
号最后一个被射杀,其实就是要求所有人都不能在号后被射杀。这种要求全部条件满足求方案数/概率的情况,就要考虑容斥,即枚举一个集合
S
S
,计算强制这个人在
1
1
号后被射杀的概率,那么答案就等于:
ans=∑S(−1)|S|p(S)
a
n
s
=
∑
S
(
−
1
)
|
S
|
p
(
S
)
可是由于游戏的每一步中,概率的分母都不同,
p(S)
p
(
S
)
很难计算,怎么办呢?我们需要对游戏做出一些转化:一个人被射杀后,他仍然参与概率的计算,但如果射中了已经被射杀的人,就再射一次,显然这和原来的游戏是等价的。这样的话,令
sum(S)=∑i∈Swi,W=∑ni=1wi
s
u
m
(
S
)
=
∑
i
∈
S
w
i
,
W
=
∑
i
=
1
n
w
i
,我们有:
p(S)=∑∞i=0(1−w1+sum(S)W)iw1W
p
(
S
)
=
∑
i
=
0
∞
(
1
−
w
1
+
s
u
m
(
S
)
W
)
i
w
1
W
把
w1W
w
1
W
提出来后,剩下的和式是一个无穷级数,因为
0<1−w1+sum(S)W<1
0
<
1
−
w
1
+
s
u
m
(
S
)
W
<
1
,所以这个级数是收敛的,那么它就等于前缀和数列的极限。我们有公式:
∑∞i=0xi=11−x
∑
i
=
0
∞
x
i
=
1
1
−
x
所以有:
p(S)=w1W⋅Ww1+sum(S)=w1w1+sum(S)
p
(
S
)
=
w
1
W
⋅
W
w
1
+
s
u
m
(
S
)
=
w
1
w
1
+
s
u
m
(
S
)
于是有:
ans=w1∑S(−1)|S|w1+sum(S)
a
n
s
=
w
1
∑
S
(
−
1
)
|
S
|
w
1
+
s
u
m
(
S
)
虽然我们极大地简化了所求的式子,但是这个还是不太好求。这时我们注意到一个条件:
∑ni=1wi≤105
∑
i
=
1
n
w
i
≤
10
5
,这启发我们分开计算每种分母的贡献。于是我们构造一个生成函数,其中
xi
x
i
项的系数就表示分母为
i
i
的数对答案贡献的分子,我们怎么算出这个生成函数呢?注意到这就等于,于是分治NTT求出后面的部分即可。这里的分治NTT就是单纯的分治+NTT,而不是CDQ分治+NTT。于是我们就解决了这一题,时间复杂度为
O(WlogWlogn)
O
(
W
log
W
log
n
)
。
以下是本人代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod=998244353;
const ll g=3;
int n,sum,rev[200010],cnt=0,siz[30];
ll w[200010],A[30][200010];
ll power(ll a,ll b)
{
ll s=1,ss=a;
b=(b+mod-1)%(mod-1);
while(b)
{
if (b&1) s=s*ss%mod;
ss=ss*ss%mod;b>>=1;
}
return s;
}
ll NTT(ll *a,int n,int type)
{
for(int i=0;i<n;i++)
if (i<rev[i]) swap(a[i],a[rev[i]]);
for(int mid=1;mid<n;mid<<=1)
{
ll W=power(g,type*(mod-1)/(mid<<1));
for(int l=0,G=(mid<<1);l<n;l+=G)
{
ll w=1;
for(int k=0;k<mid;k++,w=w*W%mod)
{
ll x=a[l+k],y=w*a[l+mid+k]%mod;
a[l+k]=(x+y)%mod;
a[l+mid+k]=(x-y+mod)%mod;
}
}
}
if (type==-1)
{
ll inv=power(n,mod-2);
for(int i=0;i<n;i++)
a[i]=a[i]*inv%mod;
}
}
void solve(int l,int r)
{
if (l==r)
{
cnt++;
A[cnt][0]=1,A[cnt][w[l]]=mod-1;
siz[cnt]=w[l];
for(int i=1;i<w[l];i++)
A[cnt][i]=0;
return;
}
int mid=(l+r)>>1;
solve(l,mid);
solve(mid+1,r);
int bit=0,x=1,a=cnt-1,b=cnt,tot=siz[a]+siz[b];
while(x<=tot) bit++,x<<=1;
rev[0]=0;
for(int i=1;i<x;i++)
rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
for(int i=siz[a]+1;i<x;i++)
A[a][i]=0;
for(int i=siz[b]+1;i<x;i++)
A[b][i]=0;
NTT(A[a],x,1),NTT(A[b],x,1);
for(int i=0;i<x;i++)
A[a][i]=A[a][i]*A[b][i]%mod;
NTT(A[a],x,-1);
cnt--;
siz[cnt]=tot;
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++)
{
scanf("%lld",&w[i]);
sum+=w[i];
}
if (n==1) printf("1");
else
{
solve(2,n);
ll ans=0;
for(int i=0;i<=sum;i++)
ans=(ans+A[1][i]*power(w[1]+i,mod-2))%mod;
printf("%lld",ans*w[1]%mod);
}
return 0;
}