【题目链接】
【思路要点】
- 维护每个单位 i i i 剩余 j j j 点生命值的概率 p i , j p_{i,j} pi,j 即可。
- 对于操作 0 0 0 ,直接更新概率,复杂度为 O ( N ) O(N) O(N) 。
- 对于操作 1 1 1 ,分别考虑每个单位 i i i 被命中的概率,我们需要计算除 i i i 以外剩余单位有 x x x 个的概率 q i , x q_{i,x} qi,x ,注意到加入一个单位对 q ∗ q_* q∗ 的影响相当于乘以一个单项式,可以先求出剩余单位有 x x x 个的概率 q 0 , x q_{0,x} q0,x 再除去对应的单项式得到 q i , x q_{i,x} qi,x ,复杂度为 O ( N 2 ) O(N^2) O(N2) 。
- 时间复杂度 O ( N Q + N 2 C ) O(NQ+N^2C) O(NQ+N2C) 。
【代码】
#include<bits/stdc++.h> using namespace std; const int MAXN = 205; const int MAXM = 105; const int P = 998244353; template <typename T> void read(T &x) { x = 0; int f = 1; char ch = getchar(); for (; !isdigit(ch); ch = getchar()) if (ch == '-') f = -f; for (; isdigit(ch); ch = getchar()) x = x * 10 + ch - '0'; x *= f; } int power(int x, int y) { if (y == 0) return 1; int tmp = power(x, y / 2); if (y % 2 == 0) return 1ll * tmp * tmp % P; else return 1ll * tmp * tmp % P * x % P; } int n, m, q, a[MAXN][MAXM], inver[MAXN]; int main() { read(n), m = 100; for (int i = 1; i <= n; i++) { int x; read(x); a[i][x] = 1; inver[i] = power(i, P - 2); } read(q); while (q--) { int opt; read(opt); if (opt == 0) { int pos, u, v; read(pos), read(u), read(v); int p = 1ll * u * power(v, P - 2) % P; for (int j = 1; j <= m; j++) { a[pos][j - 1] = (a[pos][j - 1] + 1ll * a[pos][j] * p) % P; a[pos][j] = a[pos][j] * (1ll - p + P) % P; } } else { int k; read(k); static int pos[MAXN]; for (int i = 1; i <= k; i++) read(pos[i]); static int p[MAXN], q[MAXN]; for (int i = 1; i <= k; i++) { p[i] = 0; for (int j = 1; j <= m; j++) p[i] = (p[i] + a[pos[i]][j]) % P; q[i] = (1 - p[i] + P) % P; } static int res[MAXN]; memset(res, 0, sizeof(res)); res[0] = 1; for (int i = 1; i <= k; i++) { for (int j = i; j >= 1; j--) res[j] = (1ll * res[j] * q[i] + 1ll * res[j - 1] * p[i]) % P; res[0] = 1ll * res[0] * q[i] % P; } for (int i = 1; i <= k; i++) { if (p[i] == 0) { printf("0 "); continue; } static int tmp[MAXN]; tmp[k + 1] = 0; int inv = power(p[i], P - 2); for (int j = k; j >= 1; j--) tmp[j] = (res[j] - 1ll * tmp[j + 1] * q[i] % P + P) * inv % P; int ans = 0; for (int j = 1; j <= k; j++) ans = (ans + 1ll * tmp[j] * inver[j]) % P; ans = 1ll * ans * p[i] % P; printf("%d ", ans); } printf("\n"); } } for (int i = 1; i <= n; i++) { int ans = 0; for (int j = 1; j <= m; j++) ans = (ans + 1ll * j * a[i][j]) % P; printf("%d ", ans); } printf("\n"); return 0; }