题目描述
给定 $n$ 个整数 $w_1,w_2,\cdots, w_n$,以及 $n$ 个取值在 $\{0,1\}$ 的整数 $a_0,a_1,\cdots,a_n$
一共进行 $m$ 轮操作,在每一轮,首先会随机选择一个 $[1,n]$ 的整数,其中选到 $i$ 的概率为 $\frac{w_i}{\sum_{j=1}^{n}w_j}$
如果 $a_i=0$,则 $w_i$ 变为 $w_i-1$,否则 $w_i$ 变为 $w_i+1$
即 $w_i$ 变为 $w_i-(-1)^{a_i}$
求结束后 $w_1,w_2,\cdots,w_n$ 的值的期望
题解
设 $A=\sum_{i=1}^{n}w_i \cdot [a_i=1],B=\sum_{i=1}^{n}w_i \cdot [a_i=0]$
先考虑做 $a_i=1$ 的情况,设 $f_{w}[i][j][k]$ 表示对于权值 $w$,还需要进行 $i$ 轮,其中 $a_i=1$ 的 $w$ 之和为 $j$,且 $a_i=0$ 的 $w$ 之和为 $k$,在进行完后的期望值
首先 $f_{w}[0][j][k]=w$,对于 $i \ge 1$,则有:
$$
f_{w}[i][j][k]= \frac{w}{j+k} f_{w+1}[i-1][j+1][k]+ \frac{j-w}{j+k}f_{w}[i-1][j+1][k]+\frac{k}{j+k}f_{w}[i-1][j][k-1]
$$
由于每次 $j,k$ 只会 $\pm 1$,因此只需要记录分别 $\pm 1$ 了多少次即可,同时如果 $+1$ 了 $x$ 次,且 $-1$ 了 $y$ 次,那么还剩下 $m-x-y$ 轮
也就是可以设 $f'_{w}[i][j]=f_{w}[m-i-j][A+i][B-j]$
可以证明,$f_w[i][j][k]=w \cdot f_1[i][j][k]$,证明暂略
也就是现在只需要求 $f_1[i][j][k]$ 就行了,也就是只需要求 $f_1'[i][j]$,就叫它 $f'[i][j]$ 吧,则:
$$
f'[i][j]=
\begin{cases}
1 & \quad (i+j=m) \\
\frac{2+(A+i-1)}{A+i+B-j} f'[i+1][j]+\frac{B-j}{A+i+B-j}f'[i][j+1] & \quad(i+j<m)
\end{cases}
$$
同理,设 $g'[i][j]$ 表示 $a_i=0$ 的情况,则:
$$
g'[i][j]=
\begin{cases}
1 & \quad(i+j=m) \\
\frac{0+(B-j-1)}{A+i+B-j} g'[i][j+1] + \frac{A+i}{A+i+B-j}g'[i+1][j] & \quad(i+j<m)
\end{cases}
$$
代码
1 #include <bits/stdc++.h> 2 using namespace std; 3 typedef long long ll; 4 5 const int mod = 998244353, N = 2e5 + 10, M = 3010; 6 7 ll pw(ll a, ll b) { 8 ll r = 1; 9 for( ; b ; b >>= 1, a = a * a % mod) { 10 if(b & 1) { 11 r = r * a % mod; 12 } 13 } 14 return r; 15 } 16 17 ll getinv(ll n) { 18 return pw(n, mod - 2); 19 } 20 21 ll s[2], sum, f[M][M], g[M][M]; 22 int a[N], w[N], n, m; 23 void upd(ll &x, ll y) { 24 x = (x + y) % mod; 25 } 26 27 int main() { 28 scanf("%d%d", &n, &m); 29 for(int i = 1 ; i <= n ; ++ i) { 30 scanf("%d", &a[i]); 31 } 32 for(int i = 1 ; i <= n ; ++ i) { 33 scanf("%d", &w[i]); 34 s[a[i]] += w[i]; 35 sum += w[i]; 36 } 37 ll A = s[1], B = s[0]; 38 for(int i = m ; i >= 0 ; -- i) { 39 f[i][m - i] = g[i][m - i] = 1; 40 for(int j = min(m - i - 1ll, s[0]) ; j >= 0 ; -- j) { 41 ll inv = getinv(A + B + i - j); 42 upd(f[i][j], ((A + i + 1) * f[i + 1][j] % mod + (B - j) * f[i][j + 1] % mod) * inv % mod); 43 upd(g[i][j], ((B - j - 1) * g[i][j + 1] % mod + (A + i) * g[i + 1][j] % mod) * inv % mod); 44 } 45 } 46 for(int i = 1 ; i <= n ; ++ i) { 47 printf("%lld\n", (w[i] * (a[i] ? f : g)[0][0] % mod + mod) % mod); 48 } 49 }