Sample Game
题意:
1 ∼ n 1\sim n 1∼n个数,每个数都有一个随机概率 p i p_i pi,进行如下操作:
- 按照概率随机生成一个数。
- 如果生成的数字不小于之前生成的任意一个数字,回到步骤 1 1 1,否则到步骤 3 3 3。
- 如果生成数字的个数为 n n n,那么贡献为 n 2 n^2 n2。
求期望贡献。
思路:
生成的序列肯定是一个非降的形式。我们假设最后数列的长度为 l e n len len,那么可以得到长度大于 i i i的概率 P ( l e n > i ) = ∏ i = 1 n p i t i P(len>i)=\prod\limits_{i=1}^{n}p_i^{t_i} P(len>i)=i=1∏npiti,其中 t i t_i ti表示数字 i i i出现的次数。因为在长度为 i i i的时候,我们并不知道对于 i + 1 i+1 i+1的时候,取的数是会停下来还是会继续取,但是最终的长度一定会比 i i i大,所以这里算的是 l e n > i len>i len>i的概率。
那么我们可以得到 P ( l e n = i ) = P ( l e n > i − 1 ) − P ( l e n > i ) P(len=i)=P(len>i-1)-P(len>i) P(len=i)=P(len>i−1)−P(len>i)。
所以我们贡献的期望就是:
E
(
X
)
=
∑
i
=
1
∞
i
2
P
(
l
e
n
=
i
)
=
∑
i
=
1
∞
i
2
(
P
(
l
e
n
>
i
−
1
)
−
P
(
l
e
n
>
i
)
)
=
1
2
P
(
l
e
n
>
0
)
+
∑
i
=
1
∞
(
(
i
+
1
)
2
−
i
2
)
P
(
l
e
n
>
i
)
=
∑
i
=
0
∞
(
2
i
+
1
)
P
(
l
e
n
>
i
)
\begin{aligned}E(X)&=\sum\limits_{i=1}^{\infty}i^2P(len=i)\\&=\sum\limits_{i=1}^{\infty}i^2(P(len>i-1)-P(len>i))\\&=1^2P(len>0)+\sum\limits_{i=1}^{\infty}((i+1)^2-i^2)P(len>i)\\&=\sum\limits_{i=0}^{\infty}(2i+1)P(len>i)\end{aligned}
E(X)=i=1∑∞i2P(len=i)=i=1∑∞i2(P(len>i−1)−P(len>i))=12P(len>0)+i=1∑∞((i+1)2−i2)P(len>i)=i=0∑∞(2i+1)P(len>i)
对于 P ( l e n > i ) , i ∈ [ 0 , + ∞ ) P(len>i),i\in[0,+\infty) P(len>i),i∈[0,+∞),我们可以将这个数列写成生成函数的形式 f ( x ) = ∑ i = 0 ∞ P ( l e n > i ) x i f(x)=\sum\limits_{i=0}^{\infty}P(len>i)x^i f(x)=i=0∑∞P(len>i)xi。
可得
f
′
(
x
)
=
∑
i
=
0
∞
i
P
(
l
e
n
>
i
)
x
i
−
1
f^{'}(x)=\sum\limits_{i=0}^{\infty}iP(len>i)x^{i-1}
f′(x)=i=0∑∞iP(len>i)xi−1,
所以
E
(
x
)
=
2
f
′
(
1
)
−
f
(
1
)
E(x)=2f^{'}(1)-f(1)
E(x)=2f′(1)−f(1)。
对于第 i i i个数出现 j j j次的概率,我们也可以写成一个生成函数 g i ( x ) = ∑ j = 0 ∞ p i j x j g_i(x)=\sum\limits_{j=0}^{\infty}p_i^jx^j gi(x)=j=0∑∞pijxj。
根据生成函数中相乘的组合意义,我们又可以得到 f ( x ) = ∏ i = 1 n g i ( x ) f(x)=\prod\limits_{i=1}^{n}g_i(x) f(x)=i=1∏ngi(x)。
其中
g
i
(
x
)
g_i(x)
gi(x)可以化简成
1
1
−
p
i
x
\cfrac{1}{1-p_ix}
1−pix1的形式。
推导过程就是:
∵
g
i
(
x
)
=
∑
j
=
0
∞
p
i
j
x
j
\because g_i(x)=\sum\limits_{j=0}^{\infty}p_i^jx^j
∵gi(x)=j=0∑∞pijxj
∴
p
i
x
g
i
(
x
)
=
∑
j
=
1
∞
p
i
j
x
j
\therefore p_ixg_i(x)=\sum\limits_{j=1}^{\infty}p_i^jx^j
∴pixgi(x)=j=1∑∞pijxj
∴
(
1
−
p
i
x
)
g
i
(
x
)
=
1
\therefore (1-p_ix)g_i(x)=1
∴(1−pix)gi(x)=1
∴
g
i
(
x
)
=
1
1
−
p
i
x
\therefore g_i(x)=\cfrac{1}{1-p_ix}
∴gi(x)=1−pix1
所以可以得到:
f
′
(
x
)
=
∑
i
=
1
n
(
g
i
′
(
x
)
∏
j
=
1
,
j
≠
i
n
g
j
(
x
)
)
=
∑
i
=
1
n
(
p
i
(
1
−
p
i
)
2
∏
j
=
1
,
j
≠
i
n
g
j
(
x
)
)
=
∑
i
=
1
n
(
g
i
(
x
)
p
i
1
−
p
i
∏
j
=
1
,
j
≠
i
n
g
j
(
x
)
)
=
∑
i
=
1
n
(
p
i
1
−
p
i
∏
j
=
1
n
g
j
(
x
)
)
=
f
(
x
)
∑
i
=
1
n
p
i
1
−
p
i
\begin{aligned}f^{'}(x)&=\sum\limits_{i=1}^{n}(g_i^{'}(x)\prod\limits_{j=1,j\neq i}^{n}g_j(x))\\&=\sum\limits_{i=1}^{n}(\cfrac{p_i}{(1-p_i)^2}\prod\limits_{j=1,j\neq i}^{n}g_j(x))\\&=\sum\limits_{i=1}^{n}(g_i(x)\cfrac{p_i}{1-p_i}\prod\limits_{j=1,j\neq i}^{n}g_j(x))\\&=\sum\limits_{i=1}^{n}(\cfrac{p_i}{1-p_i}\prod\limits_{j=1}^{n}g_j(x))\\&=f(x)\sum\limits_{i=1}^{n}\cfrac{p_i}{1-p_i}\end{aligned}
f′(x)=i=1∑n(gi′(x)j=1,j=i∏ngj(x))=i=1∑n((1−pi)2pij=1,j=i∏ngj(x))=i=1∑n(gi(x)1−pipij=1,j=i∏ngj(x))=i=1∑n(1−pipij=1∏ngj(x))=f(x)i=1∑n1−pipi
所以最后可以得到: f ( 1 ) = ∏ i = 1 n 1 1 − p i , f ′ ( 1 ) = f ( 1 ) ∑ i = 1 n p i 1 − p i f(1)=\prod\limits_{i=1}^{n}\cfrac{1}{1-p_i},f^{'}(1)=f(1)\sum\limits_{i=1}^{n}\cfrac{p_i}{1-p_i} f(1)=i=1∏n1−pi1,f′(1)=f(1)i=1∑n1−pipi
代码:
#include<bits/stdc++.h>
#define fi first
#define se second
#define int long long
#define mp make_pair
#define pb push_back
#define ls x<<1
#define rs x<<1|1
#define lson x<<1,l,mid
#define rson x<<1|1,mid+1,r
#define pii pair<int,int>
#define all(x) x.begin(),x.end()
#define cl(x,y) memset(x,y,sizeof(x))
#define nxtp(a,n) next_permutation(a+1,a+n+1)
#define mem(x,y,n) memset(x,y,sizeof(int)*(n+5))
const int N=1e6+10;
const int mod=998244353;
const int inf=0x3f3f3f3f;
const double eps=1e-8;
const double pi=acos(-1);
const double INF=1e18;
using namespace std;
int p[N];
int qpow(int a,int b)
{
int ans=1;
while(b)
{
if(b&1)
ans=ans*a%mod;
a=a*a%mod;
b>>=1;
}
return ans;
}
signed main()
{
ios::sync_with_stdio(false);
cin.tie(0);cout.tie(0);
int n,i,sum=0;
cin>>n;
for(i=1;i<=n;i++)
{
cin>>p[i];
sum+=p[i];
}
sum=qpow(sum,mod-2);
for(i=1;i<=n;i++)
p[i]=p[i]*sum%mod;
int f1=1,f2=0;
for(i=1;i<=n;i++)
f1=(f1*(1*qpow((1-p[i]+mod)%mod,mod-2)))%mod;
for(i=1;i<=n;i++)
f2=(f2+(p[i]*qpow((1-p[i]+mod)%mod,mod-2)))%mod;
f2=2*f2*f1%mod;
cout<<(f1+f2)%mod<<endl;
return 0;
}