题目大意:
给出一个随机数生成器,每次生成 [ 1 , n ] [1,n] [1,n],当且仅当生成的 x x x为不小于前面生成的所有值才能继续生成,否则结束。求生成数字的个数平方的期望值。
解题思路:(思路来源:Kur1su)
- 要求生成个数平方的期望值,其实先最好先求个数的期望值,再转换到平方
- 设 f [ x ] f[x] f[x]表示当前为x,之后所能得到的数字个数的期望值
- f [ x ] = 1 ⋅ ∑ i = 1 x − 1 p [ i ] + ( 1 + f [ x ] ) ⋅ p [ x ] + ∑ i = x + 1 n ( 1 + f [ i ] ) ∗ p [ i ] f[x]=1\cdot \sum_{i=1}^{x-1}p[i]+(1+f[x])\cdot p[x] + \sum_{i=x+1}^{n}(1+f[i])*p[i] f[x]=1⋅∑i=1x−1p[i]+(1+f[x])⋅p[x]+∑i=x+1n(1+f[i])∗p[i]
- 三部分分别为:抽到比x小的,抽到x,抽到比x大的
- 化简得: f [ x ] = 1 + ∑ i = x + 1 n p [ i ] f [ i ] 1 − p [ x ] f[x]=\frac{1+\sum_{i=x+1}^{n}p[i]f[i]}{1-p[x]} f[x]=1−p[x]1+∑i=x+1np[i]f[i]
- 再设 g [ x ] g[x] g[x]表示当前为 x x x(即 E ( x 2 ) E(x^2) E(x2)),而 E ( ( x + 1 ) 2 ) = E ( x 2 ) + 2 E ( x ) + 1 E((x+1)^2)=E(x^2)+2E(x)+1 E((x+1)2)=E(x2)+2E(x)+1
- 所以类似得可以列出: g [ x ] = 1 ⋅ ∑ i = 1 x − 1 p [ i ] + ( g [ x ] + 1 + 2 f [ x ] ) p [ x ] + ∑ i = x + 1 n ( g [ i ] + 1 + 2 f [ i ] ) p [ i ] g[x]=1\cdot \sum_{i=1}^{x-1}p[i]+(g[x]+1+2f[x])p[x]+\sum_{i=x+1}^{n}(g[i]+1+2f[i])p[i] g[x]=1⋅∑i=1x−1p[i]+(g[x]+1+2f[x])p[x]+∑i=x+1n(g[i]+1+2f[i])p[i]
- 化简得: g [ x ] = 1 + 2 f [ x ] p [ x ] + ∑ i = x + 1 n p [ i ] ( g [ i ] + 2 f [ i ] ) 1 − p [ x ] g[x]=\frac{1+2f[x]p[x]+\sum_{i=x+1}^{n}p[i](g[i]+2f[i])}{1-p[x]} g[x]=1−p[x]1+2f[x]p[x]+∑i=x+1np[i](g[i]+2f[i])
- 从后往前dp即可
- a n s = ∑ i = 1 n p [ i ] ( 1 + 2 f [ i ] + g [ i ] ) ans=\sum_{i=1}^{n}p[i](1+2f[i]+g[i]) ans=∑i=1np[i](1+2f[i]+g[i])
AC代码:
#include <bits/stdc++.h>
#define ft first
#define sd second
#define IOS ios::sync_with_stdio(false), cin.tie(0), cout.tie(0) //不能跟puts混用
#define seteps(N) fixed << setprecision(N)
#define endl "\n"
const int maxn = 1e3;
using namespace std;
typedef long long ll;
typedef double db;
typedef pair<int, int> pii;
const ll mod = 998244353;
int n;
ll f[maxn], g[maxn], ans;
ll mol[maxn], den, p[maxn], _1p[maxn];
ll qpow(ll a, ll b) {
ll res = 1;
while (b) {
if (b & 1) res = res * a % mod;
a = a * a % mod;
b >>= 1;
}
return res;
}
int main() {
cin >> n;
for (int i = 1; i <= n; i++) cin >> mol[i], den += mol[i];
for (int i = 1; i <= n; i++) p[i] = mol[i] * qpow(den, mod - 2) % mod, _1p[i] = (den - mol[i]) * qpow(den, mod - 2) % mod;
//可以用前缀和优化到O(n)
for (int i = n; i >= 1; i--) {
ll res = 1;
for (int j = i + 1; j <= n; j++) res = (res + p[j] * f[j] % mod) % mod;
f[i] = res * qpow(_1p[i], mod - 2) % mod;
}
// for (int i = 1; i <= n; i++) cout << f[i] << " \n"[i == n];
for (int i = n; i >= 1; i--) {
ll res = (1 + 2 * f[i] * p[i] % mod) % mod;
for (int j = i + 1; j <= n; j++) res = (res + p[j] * (g[j] + 2 * f[j]) % mod) % mod;
g[i] = res * qpow(_1p[i], mod - 2) % mod;
}
for (int i = 1; i <= n; i++)
ans = (ans + p[i] * (1 + 2 * f[i] + g[i]) % mod) % mod;
cout << ans << endl;
return 0;
}