题目链接:点击查看
题目大意:给出 n 个区间 [ l[ i ] , r[ i ] ] ,再给出 m 个限制 ( a[ i ] , b[ i ] ),求在 n 个区间中能选出多少种子集 S,满足 ,且任意一条限制都不能同时出现在 S 中
题目分析:先不考虑限制,在输入时可以对 n 个区间差分一下,差分数组记为 cnt[ i ] ,然后求一下前缀和,就可以枚举子集的大小 | S | = i,那么当子集大小 | S | 为 i 时的种类数就是 C( cnt[ i ] , i )
对于一个限制和一个子集共有四种情况:
- a[ i ] ∈ S,b[ i ] ∈ S
- a[ i ] ∈ S , b[ i ] ∉ S
- a[ i ] ∉ S , b[ i ] ∈ S
- a[ i ] ∉ S , b[ i ] ∉ S
显然只有第一种情况是不符合的,剩下的三种情况都是符合条件的,这样一来正难则反,我们可以求出所有不符合限制的子集数,然后容斥就好了,最后用全集减去不符合条件的就是答案了
因为在这个题目中,设 t 为同时生效的限制数,则全集为 t = 0 时的答案,那么这个题目就转换为了求:
( t = 0 时的答案 ) - ( t = 1 时的答案 ) + ( t = 2 时的答案 ) - ( t = 3 时的答案 ) ....
因为 m 比较小,所以直接状压就好,t 个限制同时生效,也就是是需要求 t 个区间的交,假设区间的交为 [ L , R ],那么答案显然就是,这里突然出现的 tot 是 t 个限制总共涉及了多少个点
然后实现就好了,预处理一下逆元和前缀和,时间复杂度就是 m*2^m + n*m 的了,我为了偷懒加了个set,所以复杂度变成了 m * 2^m * log(20) + n * m,就当带了点常数吧
代码:
#include<iostream>
#include<cstdio>
#include<string>
#include<ctime>
#include<cmath>
#include<cstring>
#include<algorithm>
#include<stack>
#include<climits>
#include<queue>
#include<map>
#include<set>
#include<sstream>
#include<cassert>
#include<bitset>
using namespace std;
typedef long long LL;
typedef unsigned long long ull;
const int inf=0x3f3f3f3f;
const int N=3e5+100;
const int mod=998244353;
int l[N],r[N],cnt[N],a[N],b[N];
LL fac[N],inv[N],sum[50][N];
LL q_pow(LL a,LL b)
{
LL ans=1;
while(b)
{
if(b&1)
ans=ans*a%mod;
a=a*a%mod;
b>>=1;
}
return ans;
}
LL C(int n,int m)
{
if(n<m||m<0)
return 0;
return fac[n]*inv[m]%mod*inv[n-m]%mod;
}
void init()
{
fac[0]=1;
for(int i=1;i<N;i++)
fac[i]=fac[i-1]*i%mod;
inv[N-1]=q_pow(fac[N-1],mod-2);
for(int i=N-2;i>=0;i--)
inv[i]=inv[i+1]*(i+1)%mod;
}
int main()
{
#ifndef ONLINE_JUDGE
// freopen("data.in.txt","r",stdin);
// freopen("data.out.txt","w",stdout);
#endif
// ios::sync_with_stdio(false);
init();
int n,m;
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
{
scanf("%d%d",l+i,r+i);
cnt[l[i]]++,cnt[r[i]+1]--;
}
for(int i=1;i<=n;i++)
cnt[i]+=cnt[i-1];
for(int i=0;i<m;i++)
scanf("%d%d",a+i,b+i);
for(int i=0;i<=2*m;i++)
for(int j=1;j<=n;j++)
sum[i][j]=(sum[i][j-1]+C(cnt[j]-i,j-i))%mod;
LL ans=0;
for(int i=0;i<1<<m;i++)
{
set<int>st;
int L=1,R=n;//交集
for(int j=0;j<m;j++)
if((i>>j)&1)
{
st.insert(a[j]);
st.insert(b[j]);
L=max(L,l[a[j]]);
L=max(L,l[b[j]]);
R=min(R,r[a[j]]);
R=min(R,r[b[j]]);
}
int nn=st.size();//总共涉及了多少个点
LL temp=0;
if(L<=R)
temp=(sum[nn][R]-sum[nn][L-1]+mod)%mod;
if(__builtin_popcount(i)&1)
ans=(ans-temp+mod)%mod;
else
ans=(ans+temp)%mod;
}
printf("%lld\n",ans);
return 0;
}