原题链接
组合数 + 容斥思想，关键在于如何让每个不合法方案恰好被减去一次（按路径经过的第一个/最后一个特殊点分类）。
代码如下：
#include <bits/stdc++.h>
using namespace std;
#define pii pair<int, int>
#define ll long long
// Reads the next decimal integer from stdin, skipping every non-digit
// character in front of it. A '-' directly before the first digit makes the
// result negative (any other character in between cancels it).
// Returns 0 if EOF is reached before a digit is found (the old version spun
// forever in that case).
inline int read() {
    int x = 0;
    bool negative = false;
    // Keep getchar()'s int result: EOF stays distinguishable from byte 0xFF,
    // and isdigit() only ever sees EOF or an unsigned-char value (passing a
    // negative plain char to isdigit is undefined behaviour).
    int ch = getchar();
    while (ch != EOF && !isdigit(ch)) {
        negative = (ch == '-');
        ch = getchar();
    }
    while (ch != EOF && isdigit(ch)) {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return negative ? -x : x;
}
// Size of the factorial tables and of the point array.
const int N = 200010;
// Not referenced by the visible code.
const int K = 2010;
// Prime modulus for every count in this file.
const int mod = 998244353;

// Grid corner (n, m) and number of special points k, read in main().
int n, m, k;
// fac[i] = i! mod `mod`; inv[i] = (i!)^{-1} mod `mod` (filled in main()).
int fac[N];
int inv[N];
// Zero-initialized modular accumulators.
int F, G;

// One special point: coordinates p = (x, y) plus the two path counts
// f and g that main() computes for it.
struct Node {
    std::pair<int, int> p;
    int f, g;

    // Lexicographic order on the coordinates: by x, ties broken by y.
    bool operator < (const Node &that) const {
        return p < that.p;
    }
};

// Special points, stored 1-based (nd[1..k]).
Node nd[N];
// Adds k to p in place and normalizes the result into [0, mod).
// k may be negative (callers use it to subtract).
void AddMod(int &p, int k) {
    int sum = (p + k) % mod;  // lies in (-mod, mod)
    if (sum < 0) sum += mod;
    p = sum;
}
// Binary (square-and-multiply) exponentiation: returns a^b mod p for b >= 0.
// For b == 0 it returns 1 without reducing it modulo p.
long long qpow(long long a, long long b, long long p) {
    long long result = 1ll;
    for (long long base = a, e = b; e != 0; e >>= 1) {
        if (e & 1) result = result * base % p;
        base = base * base % p;
    }
    return result;
}
// Binomial coefficient "n choose m" modulo `mod`, read from the precomputed
// fac / inv tables. Returns 0 when either argument is negative or m > n.
// NOTE(review): n is assumed to be below N (the table size); not checked.
ll C(ll n, ll m) {
    const bool inRange = n >= 0 && m >= 0 && m <= n;
    if (!inRange) return 0ll;
    ll result = fac[n];
    result = result * inv[m] % mod;
    result = result * inv[n - m] % mod;
    return result;
}
// Entry point. Reads the grid corner (n, m) and k special points, then prints
//     sum over ordered pairs (i, j), i != j, of f_i * g_j   (mod `mod`)
// where only unit +x / +y steps are allowed and
//   f_i = number of paths from (0, 0) to point i that visit no other special
//         point before i (i is the first special point on the path);
//   g_i = number of paths from point i to (n, m) that visit no other special
//         point after i (i is the last special point on the path).
// Both are obtained by inclusion-exclusion. Every path that also touches
// another special point is decomposed uniquely by its first (for f) or last
// (for g) special point, so it is subtracted exactly once.
// Assumes the points are distinct and every binomial argument stays below N
// (the factorial table size); neither is checked here.
int main() {
    n = read(); m = read(); k = read();

    // Factorials and inverse factorials modulo the prime `mod`. The top
    // inverse comes from Fermat's little theorem; the rest follow from
    // inv[i] = inv[i + 1] * (i + 1).
    fac[0] = 1;
    for (ll i = 1; i < N; ++i) fac[i] = (ll)fac[i - 1] * i % mod;
    inv[N - 1] = qpow(fac[N - 1], mod - 2, mod);
    for (ll i = N - 2; i >= 0; --i) inv[i] = (ll)inv[i + 1] * (i + 1) % mod;

    // Start from the unrestricted counts: origin -> point and point -> (n, m).
    for (int i = 1; i <= k; ++i) {
        const int x = read();
        const int y = read();
        nd[i].p = {x, y};
        nd[i].f = C(x + y, x);
        nd[i].g = C(n - x + m - y, n - x);
    }

    // Ascending (x, y) order. Any point j that can precede point i on a path
    // (x_j <= x_i and y_j <= y_i) comes earlier, so f_j is final here.
    // Subtract the paths whose first special point is such a j.
    sort(nd + 1, nd + k + 1);
    for (int i = 1; i <= k; ++i) {
        for (int j = 1; j < i; ++j) {
            if (nd[j].p.second > nd[i].p.second) continue;  // j cannot precede i
            const int dx = nd[i].p.first - nd[j].p.first;
            const int dy = nd[i].p.second - nd[j].p.second;
            const int d = nd[j].f * C(dx + dy, dx) % mod;
            AddMod(nd[i].f, -d);
        }
    }

    // Descending order. Any point j that can follow point i on a path
    // (x_j >= x_i and y_j >= y_i) now comes earlier, so g_j is final here.
    // Subtract the paths whose last special point is such a j.
    reverse(nd + 1, nd + k + 1);
    for (int i = 1; i <= k; ++i) {
        for (int j = 1; j < i; ++j) {
            if (nd[j].p.second < nd[i].p.second) continue;  // j cannot follow i
            const int dx = nd[j].p.first - nd[i].p.first;
            const int dy = nd[j].p.second - nd[i].p.second;
            const int d = nd[j].g * C(dx + dy, dx) % mod;
            AddMod(nd[i].g, -d);
        }
    }

    // Distributive law: sum_{i != j} f_i * g_j = sum_i f_i * (G - g_i) with
    // G = sum_j g_j, which turns the O(k^2) pair sum into O(k).
    for (int i = 1; i <= k; ++i) AddMod(G, nd[i].g);
    int ans = 0;
    for (int i = 1; i <= k; ++i) {
        const int others = ((G - nd[i].g) % mod + mod) % mod;
        AddMod(ans, (ll)nd[i].f * others % mod);
    }
    printf("%d\n", ans);
    return 0;
}