【题目链接】
【思路要点】
- 考虑 $k=0$ 的情况，可以发现对于阵营和派系的选择/限制相互独立，因此直接分开决策各个城市的阵营（对城市人数做背包），以及各个学校的派系（对学校人数做背包），再将两部分满足人数上限的方案数相乘即可。
- 对于 $k \ne 0$ 的情况，将所涉及到的城市和学校拿出来单独做 $dp$，再与其余城市、学校的背包结果合并即可。注意到拿出来的学校总人数不超过 $ks$，本部分复杂度为 $O(k^2 s \times M)$。
- 时间复杂度 $O(NM + k^2 s \times M)$。
【代码】
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <vector>

// Counts (mod 998244353) the ways to assign every school to one of four
// (camp, faction) pairs under four capacity limits, as read from stdin:
//   * all schools of a city share the city's camp; camp 0 holds at most cx
//     students and camp 1 at most cy;
//   * each school picks its faction independently; faction 0 holds at most dx
//     students and faction 1 at most dy;
//   * k schools each forbid one (camp, faction) pair, encoded 2*camp + faction.
// Unconstrained cities/schools are counted with two independent 1-D knapsacks;
// constrained cities (with their constrained schools) go through a joint DP,
// and the two parts are combined with prefix sums over the allowed ranges.
namespace {

// Array bounds, presumably sized for the judge's limits -- TODO confirm.
constexpr int MAXN = 2005;  // schools / cities
constexpr int MAXM = 2505;  // capacity values cx / dx
constexpr int MAXS = 605;   // faction-0 students among constrained schools
constexpr int MAXK = 65;    // items in the constrained DP (cities + schools)
constexpr int P = 998244353;

int n, m, k, cx, cy, dx, dy;        // schools, cities, #constraints, capacities
int b[MAXN], s[MAXN];               // city and student count of each school
std::vector<int> schools_of[MAXN];  // schools located in each city
int city_size[MAXN];                // students per city (not `size`: clashes with std::size)
int kc[MAXN];                       // != -1 if the city holds a constrained school
int kd[MAXN];                       // forbidden (camp, faction) code per school, -1 = none
int ways_c[MAXM], ways_d[MAXM];     // knapsack counts for the free cities / schools
int sc[MAXM], sd[MAXM];             // inclusive prefix sums of ways_c / ways_d
int tot, cnt[MAXK], type[MAXK];     // constrained items: type 0 = city, else school id
int dp[2][MAXM][MAXS][2];           // rolling joint DP over the constrained items

// Reads an integer from stdin, skipping non-digit characters (a '-' among
// them flips the sign). Stops cleanly at EOF instead of looping forever.
template <typename T>
void read(T &x) {
    x = 0;
    int sign = 1;
    int ch = std::getchar();  // int, so EOF and bytes >= 0x80 are valid for isdigit
    for (; ch != EOF && !std::isdigit(ch); ch = std::getchar())
        if (ch == '-') sign = -sign;
    for (; std::isdigit(ch); ch = std::getchar()) x = x * 10 + (ch - '0');
    x *= sign;
}

// Returns (x + y) mod P for x, y in [0, P).
int add(int x, int y) {
    x += y;
    return x >= P ? x - P : x;
}

// x = (x + y) mod P for x, y in [0, P).
void update(int &x, int y) { x = add(x, y); }

// Sum of the original values at indices l..r (inclusive) from inclusive
// prefix sums; requires 0 <= l <= r.
int range_sum(const int *prefix, int l, int r) {
    return l == 0 ? prefix[r] : (prefix[r] - prefix[l - 1] + P) % P;
}

// prefix[i] = ways[0] + ... + ways[i] (mod P) for 0 <= i <= cap.
void prefix_sums(const int *ways, int cap, int *prefix) {
    prefix[0] = ways[0];
    for (int i = 1; i <= cap; ++i) prefix[i] = add(prefix[i - 1], ways[i]);
}

// ways_c[j] (0 <= j <= cx): camp choices of the unconstrained, non-empty
// cities putting exactly j students in camp 0.
// ways_d[j] (0 <= j <= dx): faction choices of the unconstrained schools
// putting exactly j students in faction 0.
// Both are 0/1 knapsacks; iterating j downward uses each item at most once.
void count_free_choices() {
    std::fill(ways_c, ways_c + cx + 1, 0);
    ways_c[0] = 1;
    for (int city = 1; city <= m; ++city) {
        const int w = city_size[city];
        // Constrained cities go to the joint DP; an empty city has no school
        // whose assignment its camp could affect.
        if (kc[city] != -1 || w == 0) continue;
        for (int j = cx; j >= w; --j) update(ways_c[j], ways_c[j - w]);
    }

    std::fill(ways_d, ways_d + dx + 1, 0);
    ways_d[0] = 1;
    for (int school = 1; school <= n; ++school) {
        if (kd[school] != -1) continue;  // handled by the joint DP
        const int w = s[school];
        for (int j = dx; j >= w; --j) update(ways_d[j], ways_d[j - w]);
    }
}

// Builds the ordered item list for the joint DP: every constrained city
// (type 0, weight = its total students) immediately followed by its
// constrained schools (type = school id, weight = that school's students).
// Returns the total students of the constrained schools.
int build_constrained_items() {
    tot = 0;
    int school_students = 0;
    for (int city = 1; city <= m; ++city) {
        if (kc[city] == -1) continue;
        type[++tot] = 0;
        cnt[tot] = city_size[city];
        for (int school : schools_of[city]) {
            if (kd[school] == -1) continue;
            type[++tot] = school;
            cnt[tot] = s[school];
            school_students += s[school];
        }
    }
    return school_students;
}

// Joint DP over the constrained items. dp[layer][j][d][c] counts the choices
// so far that put j students in camp 0 (constrained cities only), d students
// in faction 0 (constrained schools only), with the most recent city in
// camp c. Only j <= cx and d <= goal are tracked (requires goal < MAXS).
// Every cell of that region is overwritten in each new layer, so only the
// starting layer needs clearing. Returns the layer holding the final counts.
int run_constrained_dp(int goal) {
    for (int j = 0; j <= cx; ++j)
        for (int d = 0; d <= goal; ++d) dp[0][j][d][0] = dp[0][j][d][1] = 0;
    dp[0][0][0][0] = 1;

    int from = 0;
    for (int i = 1; i <= tot; ++i) {
        const int to = from ^ 1;
        const int w = cnt[i];
        for (int j = 0; j <= cx; ++j)
            for (int d = 0; d <= goal; ++d) {
                int *cur = dp[to][j][d];
                if (type[i] == 0) {
                    // City: camp 0 consumes w of cx; camp 1 consumes nothing
                    // tracked here. The previous city's camp no longer matters.
                    cur[0] = j >= w ? add(dp[from][j - w][d][0], dp[from][j - w][d][1]) : 0;
                    cur[1] = add(dp[from][j][d][0], dp[from][j][d][1]);
                } else {
                    // School: camp is inherited from its city; faction 0
                    // consumes w of dx. Skip the forbidden (camp, faction) code.
                    const int bad = kd[type[i]];
                    int camp0 = 0;
                    int camp1 = 0;
                    if (d >= w) {
                        if (bad != 0) camp0 = dp[from][j][d - w][0];
                        if (bad != 2) camp1 = dp[from][j][d - w][1];
                    }
                    if (bad != 1) camp0 = add(camp0, dp[from][j][d][0]);
                    if (bad != 3) camp1 = add(camp1, dp[from][j][d][1]);
                    cur[0] = camp0;
                    cur[1] = camp1;
                }
            }
        from = to;
    }
    return from;
}

// Reads one test case from stdin and returns its answer mod P.
int solve_case() {
    read(n), read(m);
    read(cx), read(cy), read(dx), read(dy);
    for (int city = 1; city <= m; ++city) {
        schools_of[city].clear();
        kc[city] = -1;
        city_size[city] = 0;
    }

    int total = 0;  // total number of students
    for (int school = 1; school <= n; ++school) {
        read(b[school]), read(s[school]);
        schools_of[b[school]].push_back(school);
        kd[school] = -1;
        total += s[school];
        city_size[b[school]] += s[school];
    }

    read(k);
    for (int i = 1; i <= k; ++i) {
        int x, y;
        read(x), read(y);
        kd[x] = y;     // school x forbids (camp, faction) code y
        kc[b[x]] = y;  // its city moves into the joint DP
    }

    // The whole case has been read; if the students cannot fit within both
    // camps' (or both factions') combined capacity, no assignment exists.
    if (total > cx + cy || total > dx + dy) return 0;

    count_free_choices();
    prefix_sums(ways_c, cx, sc);
    prefix_sums(ways_d, dx, sd);

    // Faction-0 students from constrained schools never exceed their total.
    const int goal = std::min(dx, build_constrained_items());
    const int last = run_constrained_dp(goal);

    int ans = 0;
    for (int j = 0; j <= cx; ++j)
        for (int d = 0; d <= goal; ++d) {
            const int ways = add(dp[last][j][d][0], dp[last][j][d][1]);
            if (ways == 0) continue;
            // Free cities must bring camp 0 into [total - cy, cx] overall;
            // total <= cx + cy guarantees a non-empty range.
            const int free_c = range_sum(sc, std::max(0, total - cy - j), cx - j);
            // Free schools must bring faction 0 into [total - dy, dx] overall.
            const int free_d = range_sum(sd, std::max(0, total - dy - d), dx - d);
            update(ans, static_cast<int>(1LL * ways * free_c % P * free_d % P));
        }
    return ans;
}

}  // namespace

// Input: the number of test cases, then each case; prints one answer per line.
int main() {
    int T;
    read(T);
    while (T--) std::printf("%d\n", solve_case());
    return 0;
}