Description
送你在数轴上的n 个区间和m 个关键点, 你可以决定每个区间选或不选, 问有多少种方案覆盖
所有的关键点. 对1000000009 取模.
Input
第一行两个整数n;m, 分别表示区间个数和关键点个数.
接下来n 行, 每行两个整数li; ri, 表示一个区间[li; ri].
接下来m 行, 每行一个整数, 第i 行表示表示第i 个关键点xi.
Output
输出一行一个整数, 表示答案.
Sample Input
4 4
3 8
1 6
3 8
2 7
8
4
6
3
Sample Output
12
Constraints
对于100% 的数据, 1<=n.m<=500000;1<=xi<=109;1<=li<=ri<=109.
Solution
首先离散化
将关键点排序
设dp:f[i]表示包含前i个关键点的方案数
考虑每个区间,包含了[l,r]的关键点
那么,f[l-1]~f[r-1]可以累加到f[r]上,f[r]以后的可以*2
这个用线段树维护就好了
Code
#include<cstdio>
#include<cstring>
#include<algorithm>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
#define N 2001000
#define mo 1000000009
#define ll long long
using namespace std;
int m,n,b[N],tot=0,las[N],bz[N],d[N],x,y,t[N],lz[N];
struct node{
int x,y,z;
}a[N],c[N];
bool cnt(node a,node b){return a.x<b.x;}
bool cmt(node a,node b){return a.x<b.x||(a.x==b.x&&a.y>b.y);}
void read(int &x)
{
char c=getchar();x=0;
for(;c<'0'||c>'9';c=getchar());
for(;c>='0'&&c<='9';c=getchar()) x=x*10+c-48;
}
void down(int v)
{
if(lz[v]==1) return;
t[v+v]=(ll)t[v+v]*lz[v]%mo;
t[v+v+1]=(ll)t[v+v+1]*lz[v]%mo;
lz[v+v]=(ll)lz[v+v]*lz[v]%mo;
lz[v+v+1]=(ll)lz[v+v+1]*lz[v]%mo;
lz[v]=1;
}
void ins(int v,int i,int j,int x,int y)
{
if(i==x&&j==y)
{
t[v]=t[v]+t[v];lz[v]=lz[v]+lz[v];
t[v]=t[v]>mo?t[v]-mo:t[v];
lz[v]=lz[v]>mo?lz[v]-mo:lz[v];
return;
}
int m=(i+j)>>1;down(v);
if(y<=m) ins(v+v,i,m,x,y);
else if(x>m) ins(v+v+1,m+1,j,x,y);
else ins(v+v,i,m,x,m),ins(v+v+1,m+1,j,m+1,y);
t[v]=t[v+v]+t[v+v+1];
t[v]=t[v]>mo?t[v]-mo:t[v];
}
int get(int v,int i,int j,int x,int y)
{
if(i>=x&&j<=y) return t[v];
int m=(i+j)>>1;down(v);
int sum;
if(y<=m) return get(v+v,i,m,x,y);
else if(x>m) return get(v+v+1,m+1,j,x,y);
else sum=get(v+v,i,m,x,m)+get(v+v+1,m+1,j,m+1,y);
sum=sum>mo?sum-mo:sum;
return sum;
}
void change(int v,int i,int j,int x,int y)
{
if(i==j)
{
t[v]=y;lz[v]=1;
return;
}
int m=(i+j)>>1;down(v);
if(x<=m) change(v+v,i,m,x,y);
else change(v+v+1,m+1,j,x,y);
t[v]=(t[v+v]+t[v+v+1])%mo;
}
int main()
{
freopen("xmasinterval.in","r",stdin);
freopen("xmasinterval.out","w",stdout);
scanf("%d%d",&n,&m);
fo(i,1,n)
{
read(a[i].x);read(a[i].y);
c[i*2-1].x=a[i].x,c[i*2].x=a[i].y;
c[i*2-1].y=c[i*2].y=i;
c[i*2-1].z=1;c[i*2].z=2;
}
fo(i,1,m) read(d[i]);
sort(d+1,d+m+1);
fo(i,1,m) if(d[i]!=d[i-1]) b[++b[0]]=d[i];
m=b[0];
fo(i,1,m) c[i+n*2].x=b[i],c[i+n*2].y=i,c[i+n*2].z=3;
sort(c+1,c+n+n+m+1,cnt);
fo(i,1,n+n+m+1)
{
if(c[i].x!=c[i-1].x) tot++;
if(c[i].z==1) a[c[i].y].x=tot;
if(c[i].z==2) a[c[i].y].y=tot;
if(c[i].z==3) b[c[i].y]=tot;
}
int jy=0;
fo(i,1,tot)
{
while(b[jy+1]<i&&jy<m) jy++;
las[i]=jy;
bz[i]=jy;
if(b[jy+1]==i) bz[i]++;
}
sort(a+1,a+n+1,cmt);
fo(i,1,N-1) lz[i]=1;
change(1,0,m,0,1);
fo(i,1,n)
{
ins(1,0,m,bz[a[i].y],m);
jy=get(1,0,m,las[a[i].x],bz[a[i].y]);
change(1,0,m,bz[a[i].y],jy);
}
printf("%d\n",get(1,1,m,m,m));
}