题意
有
n
n
n个初始均处于关闭状态的开关,每个开关有一个权值
p
i
p_i
pi,每一轮会随机选择一个开关并改变其状态,且抽中第
i
i
i个开关的概率为
p
i
∑
p
i
\frac{p_i}{\sum{p_i}}
∑pipi。问变为目标状态期望需要经过多少轮。
n
≤
100
,
∑
p
i
≤
5
∗
1
0
4
n\le100,\sum{p_i}\le5*10^4
n≤100,∑pi≤5∗104
分析
设
F
(
x
)
F(x)
F(x)表示经过
n
n
n轮后恰好变为目标状态的概率的
E
G
F
EGF
EGF,
G
(
x
)
G(x)
G(x)表示经过
n
n
n轮后恰好回到初始状态的概率的
E
G
F
EGF
EGF,目标状态为
t
i
t_i
ti,显然有
F
(
x
)
=
∏
i
=
1
n
e
p
i
x
+
(
−
1
)
t
i
e
−
p
i
x
2
F(x)=\prod_{i=1}^n\frac{e^{p_ix}+(-1)^{t_i}e^{-p_ix}}{2}
F(x)=i=1∏n2epix+(−1)tie−pix
G
(
x
)
=
∏
i
=
1
n
e
p
i
x
+
e
−
p
i
x
2
G(x)=\prod_{i=1}^n\frac{e^{p_ix}+e^{-p_ix}}{2}
G(x)=i=1∏n2epix+e−pix
设
H
(
x
)
H(x)
H(x)为经过
n
n
n轮后第一次变为目标状态的概率的
E
G
F
EGF
EGF,
f
(
x
)
,
g
(
x
)
,
h
(
x
)
f(x),g(x),h(x)
f(x),g(x),h(x)分别表示
F
(
x
)
,
G
(
x
)
,
H
(
x
)
F(x),G(x),H(x)
F(x),G(x),H(x)的
O
G
F
OGF
OGF,那么有
f
(
x
)
=
g
(
x
)
h
(
x
)
f(x)=g(x)h(x)
f(x)=g(x)h(x)
从而得到
h
(
x
)
=
f
(
x
)
g
(
x
)
h(x)=\frac{f(x)}{g(x)}
h(x)=g(x)f(x)
显然我们要求的就是
h
′
(
1
)
h'(1)
h′(1)
考虑如果知道
E
G
F
EGF
EGF怎么求
O
G
F
OGF
OGF,设
F
(
x
)
=
∑
a
e
v
x
F(x)=\sum ae^{vx}
F(x)=∑aevx,则
f
(
x
)
=
∑
a
1
−
v
x
f(x)=\sum \frac{a}{1-vx}
f(x)=∑1−vxa
于是我们可以先通过dp求出
F
(
x
)
F(x)
F(x)和
G
(
x
)
G(x)
G(x),那么
h
′
=
(
f
g
)
′
=
f
′
g
−
f
g
′
g
2
h'=(\frac{f}{g})'=\frac{f'g-fg'}{g^2}
h′=(gf)′=g2f′g−fg′
注意到
f
(
x
)
f(x)
f(x)和
g
(
x
)
g(x)
g(x)的分母中可能含有
(
1
−
x
)
(1-x)
(1−x)的项,可以先把两边同乘
(
1
−
x
)
(1-x)
(1−x),通过推导容易得到
(
1
−
x
1
−
v
x
)
′
∣
x
=
1
=
1
v
−
1
(\frac{1-x}{1-vx})'\bigg|_{x=1}=\frac{1}{v-1}
(1−vx1−x)′∣∣∣∣x=1=v−11
时间复杂度
O
(
n
∑
p
i
)
O(n\sum p_i)
O(n∑pi)
代码
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
typedef long long LL;
const int N=50005;
const int MOD=998244353;
int n,sum,F[N*2],G[N*2],t[N],p[N],tmp[N*2];
int ksm(int x,int y)
{
int ans=1;
while (y)
{
if (y&1) ans=(LL)ans*x%MOD;
x=(LL)x*x%MOD;y>>=1;
}
return ans;
}
void solve()
{
int inv=ksm(2,MOD-2);
F[sum]=1;
for (int i=1;i<=n;i++)
{
int x=inv,y=!t[i]?inv:MOD-inv;
for (int j=0;j<=sum*2;j++)
{
tmp[j]=0;
if (j>=p[i]) (tmp[j]+=(LL)x*F[j-p[i]]%MOD)%=MOD;
if (j+p[i]<=sum*2) (tmp[j]+=(LL)y*F[j+p[i]]%MOD)%=MOD;
}
for (int j=0;j<=sum*2;j++) F[j]=tmp[j];
}
G[sum]=1;
for (int i=1;i<=n;i++)
{
int x=inv,y=inv;
for (int j=0;j<=sum*2;j++)
{
tmp[j]=0;
if (j>=p[i]) (tmp[j]+=(LL)x*G[j-p[i]]%MOD)%=MOD;
if (j+p[i]<=sum*2) (tmp[j]+=(LL)y*G[j+p[i]]%MOD)%=MOD;
}
for (int j=0;j<=sum*2;j++) G[j]=tmp[j];
}
}
int calc()
{
int inv=ksm(sum,MOD-2),f1=0,g1=0,f2=0,g2=0;
for (int i=0;i<=sum*2;i++)
{
int x=(LL)(i-sum)*inv%MOD;
x+=x<0?MOD:0;
if (x==1) (f1+=F[i])%=MOD,(g1+=G[i])%=MOD;
else (f2+=(LL)F[i]*ksm(x-1,MOD-2)%MOD)%=MOD,(g2+=(LL)G[i]*ksm(x-1,MOD-2)%MOD)%=MOD;
}
return (LL)((LL)f2*g1%MOD-(LL)f1*g2%MOD)*ksm((LL)g1*g1%MOD,MOD-2)%MOD;
}
int main()
{
scanf("%d",&n);
for (int i=1;i<=n;i++) scanf("%d",&t[i]);
for (int i=1;i<=n;i++) scanf("%d",&p[i]),sum+=p[i];
solve();
printf("%d\n",(calc()+MOD)%MOD);
return 0;
}