Educational Codeforces Round 94 (Rated for Div. 2)
G Mercenaries
题意:从n(
n
≤
3
e
5
n\leq3e5
n≤3e5)个人中雇佣一些人组成队伍,雇佣第i个人的条件是:队伍中人数在[
l
i
l_{i}
li,
r
i
r_{i}
ri]中。有m(
m
≤
20
m\leq20
m≤20)对人有仇,不能同时雇佣。
思路:枚举人数,统计当前人数的雇佣方案。利用容斥原理除去不合法方案。
- 统计总方案数:设人数为i时,有 c n t i cnt_{i} cnti个人可以选,那么总共有 C c n t i i C_{cnt_{i}}^{i} Ccntii中选择方案。
- 统计冲突方案:当人数为i,有 c n t i cnt_{i} cnti个人可以选,其中有j个人在冲突序列中,那么选择这j个人的方案数为 C c n t i − j i − j C_{cnt_{i}-j}^{i-j} Ccnti−ji−j。预处理出冲突方案的前缀和。
- 枚举状态,计算出每种状态的人数上下界以及在冲突序列中的人数,利用前缀和进行容斥。
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<vector>
using namespace std;
typedef long long LL;
const int N = 3e5 + 10;
const LL mod = 998244353;
LL f[N], inv[N], sum[42][N];
int L[N], R[N], cnt[N], h[N * 2], ha[N], vist[30];
struct PAIR
{
int x, y;
PAIR(){}
PAIR(int a, int b){x = a; y = b;}
};
vector<PAIR>pir;
LL get_inv(LL x)
{
LL c = 1;
LL p = mod - 2;
while(p)
{
if(p & 1) c = c * x % mod;
x = x * x % mod;
p >>= 1;
}
return c;
}
LL C(int n, int m)
{
if(n < m || n < 0 || m < 0) return 0;
return f[n] * inv[m] % mod * inv[n - m] % mod;
}
int main()
{
int n, m;
scanf("%d%d", &n, &m);
for(int i = 1; i <= n; i++)
{
scanf("%d%d", &L[i], &R[i]);
cnt[L[i]]++;
cnt[R[i] + 1]--;
}
for(int i = 1; i <= n; i++) cnt[i] += cnt[i - 1];
int sz = 0;
for(int i = 1; i <= m; i++)
{
int x, y;
scanf("%d%d", &x, &y);
pir.push_back(PAIR(x, y));
h[++sz] = x;
h[++sz] = y;
}
sort(h + 1, h + sz + 1);
sz = unique(h + 1, h + sz + 1) - h - 1;
for(int i = 1; i <= sz; i++) ha[h[i]] = i;
f[0] = 1;
for(int i = 1; i <= n; i++) f[i] = f[i - 1] * i % mod;
for(int i = 0; i <= n; i++) inv[i] = get_inv(f[i]);
LL ans = 0;
for(int i = 1; i <= n; i++) ans = (ans + C(cnt[i], i)) % mod;
for(int i = 0; i <= m * 2; i++)
{
for(int j = 1; j <= n; j++)
{
if(cnt[j] >= i && j >= i) sum[i][j] = C(cnt[j] - i, j - i);
sum[i][j] = (sum[i][j] + sum[i][j - 1]) % mod;
}
}
for(int i = 1; i < (1 << m); i++)
{
int l = 1;
int r = n;
int tot = 0, sztot = 0;;
for(int j = 1; j <= sz; j++) vist[j] = 0;
for(int j = 0; j < m; j++)
if(i & (1 << j))
{
int x = pir[j].x;
int y = pir[j].y;
vist[ha[x]] = 1;
vist[ha[y]] = 1;
l = max(l, L[x]);
l = max(l, L[y]);
r = min(r, R[x]);
r = min(r, R[y]);
sztot++;
}
for(int j = 1; j <= sz; j++) if(vist[j]) tot++;
if(l <= r)
{
LL now =(sum[tot][r] - sum[tot][l - 1] + mod) % mod;
if(sztot & 1) ans = (ans - now + mod) % mod;
else ans = (ans + now) % mod;
}
}
cout << ans << endl;
return 0;
}