您需要写一种数据结构(可参考题目标题),来维护一个有序数列,其中需要提供以下操作:
①查询k在区间内的排名
②查询区间内排名为k的值
③修改某一位值上的数值
④查询k在区间内的前驱(前驱定义为严格小于x,且最大的数,若不存在输出-2147483647)
⑤查询k在区间内的后继(后继定义为严格大于x,且最小的数,若不存在输出2147483647)
#注意上面两条要求和tyvj或者bzoj不一样,请注意
输入格式
第一行两个数 n,m 表示长度为n的有序序列和m个操作
第二行有n个数,表示有序序列
下面有m行,opt表示操作标号
若opt=1 则为操作1,之后有三个数l,r,k 表示查询k在区间[l,r]的排名
若opt=2 则为操作2,之后有三个数l,r,k 表示查询区间[l,r]内排名为k的数
若opt=3 则为操作3,之后有两个数pos,k 表示将pos位置的数修改为k
若opt=4 则为操作4,之后有三个数l,r,k 表示查询区间[l,r]内k的前驱
若opt=5 则为操作5,之后有三个数l,r,k 表示查询区间[l,r]内k的后继
输出格式
对于操作1,2,4,5各输出一行,表示查询结果
解法一:
这道题最经典的做法便是线段树套平衡树。
空间复杂度为O(nlogn)。
也就是说,我们用一颗线段树管理下标区间 [L,R],然后每个线段树的节点都是一颗平衡树,平衡树中存储了所有下标在 [L,R] 之间的所有数的信息。这样的话,所有区间都可以被分为 O(logN) 个线段树的节点,区间问题便转化为了几个平衡树上的问题。
然后我们考虑如何具体实现各个操作:
1、查询区间内一个数的排名:在线段树上找到区间对应的节点,然后每个节点的平衡树内查询对应数的排名并求和。时间复杂度 O(log2N)
2、查询区间内排名为k的数是几:由于这项操作在线段树上不可加。所以我们考虑转换为判定一个数是不是区间内排名为k的。这个可以在 O(log2N) 的时间内通过操作1完成。那么我们考虑二分答案,就可以解决这个问题。时间复杂度O(log3N)
3、单点修改:我们在线段树上找到所有覆盖这个点的区间,然后在所有区间对应的平衡树中删除原数,加入新数即可。时间复杂度 O(log2N)
4,5、 查询区间内一个数前驱、后继:这是平衡树上的经典问题。我们只需要对于所有区间分别查询,然后答案相应的取 max min 即可。时间复杂度O(log2N)
解法二:
带修主席树,空间复杂度为O(n log2n)
但是各项操作的时间复杂度均为O(n log2n)
懒得用主席树写了,就用线段树套平衡树写了,这里平衡树用SBT。
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<string>
#include<queue>
#include<set>
#define ll long long
#define llu unsigned ll
using namespace std;
const int inf=2147483647;
const int maxn=50100;
struct SBT
{
int val;
int lc,rc;
int si;
}t[maxn*22];
int tot;
void pushup(int p)
{
t[p].si=t[t[p].lc].si+t[t[p].rc].si+1;
}
int newnode(int x)
{
int p=++tot;
t[p].lc=t[p].rc=0;
t[p].si=1;
t[p].val=x;
return p;
}
void R(int &p)
{
int q=t[p].lc;
t[p].lc=t[q].rc;
t[q].rc=p;
p=q;
pushup(t[p].rc);
pushup(p);
}
void L(int &p)
{
int q=t[p].rc;
t[p].rc=t[q].lc;
t[q].lc=p;
p=q;
pushup(t[p].lc);
pushup(p);
}
void RL(int &p)
{
R(t[p].rc);
L(p);
}
void LR(int &p)
{
L(t[p].lc);
R(p);
}
int getmin(int p)
{
while(t[p].lc) p=t[p].lc;
return p;
}
int getmax(int p)
{
while(t[p].rc) p=t[p].rc;
return p;
}
void maintain(int &p,bool flag)
{
if(!flag)
{
if(t[t[t[p].lc].lc].si>t[t[p].rc].si)
R(p);
else if(t[t[t[p].lc].rc].si>t[t[p].rc].si)
LR(p);
else return ;
}
else
{
if(t[t[t[p].rc].rc].si>t[t[p].lc].si)
L(p);
else if(t[t[t[p].rc].lc].si>t[t[p].lc].si)
RL(p);
else return ;
}
maintain(t[p].lc,false);
maintain(t[p].rc,true);
maintain(p,true);
maintain(p,false);
}
void _insert(int &now,int val)
{
if(!now)
{
now=newnode(val);
return ;
}
if(val<t[now].val)
{
_insert(t[now].lc,val);
maintain(now,false);
}
else
{
_insert(t[now].rc,val);
maintain(now,true);
}
pushup(now);
}
void del(int &p,int x)
{
if(x<t[p].val)
del(t[p].lc,x);
else if(x>t[p].val)
del(t[p].rc,x);
else
{
if(t[p].lc&&t[p].rc)
{
int q=getmax(t[p].lc);
t[p].val=t[q].val;
del(t[p].lc,t[q].val);
}
else
{
if(t[p].lc) p=t[p].lc;
else p=t[p].rc;
return ;
}
}
pushup(p);
}
int get_rank(int root,int x)
{
int now=root,ans=0;
while(now)
{
if(t[now].val<x) ans+=t[t[now].lc].si+1,now=t[now].rc;
else now=t[now].lc;
}
return ans+1;
}
int get_val(int root,int k)
{
int now=root;
while(1)
{
if(t[t[now].lc].si+1==k) return t[now].val;
else if(t[t[now].lc].si>=k) now=t[now].lc;
else k-=t[t[now].lc].si+1,now=t[now].rc;
}
}
int get_front(int root,int x)
{
int now=root,ans=-inf;
while(now)
{
if(t[now].val<x) ans=max(ans,t[now].val),now=t[now].rc;
else now=t[now].lc;
}
return ans;
}
int get_behind(int root,int x)
{
int now=root,ans=inf;
while(now)
{
if(t[now].val>x) ans=min(ans,t[now].val),now=t[now].lc;
else now=t[now].rc;
}
return ans;
}
struct node
{
int l,r,rt;
}tt[maxn<<2];
int a[maxn];
void build(int l,int r,int cnt)
{
tt[cnt].l=l,tt[cnt].r=r;
tt[cnt].rt=0;
for(int i=l;i<=r;i++)
_insert(tt[cnt].rt,a[i]);
if(l==r) return ;
int mid=(l+r)>>1;
build(l,mid,cnt<<1);
build(mid+1,r,cnt<<1|1);
}
int ask_rk(int l,int r,int k,int cnt)
{
if(l<=tt[cnt].l&&tt[cnt].r<=r)
return get_rank(tt[cnt].rt,k)-1;
int ans=0;
if(l<=tt[cnt<<1].r) ans+=ask_rk(l,r,k,cnt<<1);
if(r>=tt[cnt<<1|1].l) ans+=ask_rk(l,r,k,cnt<<1|1);
return ans;
}
int ask1(int l,int r,int k)
{
return ask_rk(l,r,k,1)+1;
}
int ask2(int nl,int nr,int k)
{
int l=0,r=1e8;
int mid,pos=0;
while(l<=r)
{
mid=(l+r)>>1;
if(ask1(nl,nr,mid)<=k) pos=mid,l=mid+1;
else r=mid-1;
}
return pos;
}
void change(int cnt,int x,int y)
{
del(tt[cnt].rt,a[x]);
_insert(tt[cnt].rt,y);
if(tt[cnt].l==tt[cnt].r) return ;
if(x<=tt[cnt<<1].r) change(cnt<<1,x,y);
else change(cnt<<1|1,x,y);
}
void do3(int x,int y)
{
change(1,x,y);
a[x]=y;
}
int ask4(int l,int r,int cnt,int k)
{
if(l<=tt[cnt].l&&tt[cnt].r<=r)
return get_front(tt[cnt].rt,k);
int ans=-inf;
if(l<=tt[cnt<<1].r) ans=max(ans,ask4(l,r,cnt<<1,k));
if(r>=tt[cnt<<1|1].l) ans=max(ans,ask4(l,r,cnt<<1|1,k));
return ans;
}
int ask5(int l,int r,int cnt,int k)
{
if(l<=tt[cnt].l&&tt[cnt].r<=r)
return get_behind(tt[cnt].rt,k);
int ans=inf;
if(l<=tt[cnt<<1].r) ans=min(ans,ask5(l,r,cnt<<1,k));
if(r>=tt[cnt<<1|1].l) ans=min(ans,ask5(l,r,cnt<<1|1,k));
return ans;
}
int main(void)
{
int n,m;
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
scanf("%d",&a[i]);
build(1,n,1);
int op,l,r,k,pos;
for(int i=1;i<=m;i++)
{
scanf("%d",&op);
if(op==1)
{
scanf("%d%d%d",&l,&r,&k);
printf("%d\n",ask1(l,r,k));
}
else if(op==2)
{
scanf("%d%d%d",&l,&r,&k);
printf("%d\n",ask2(l,r,k));
}
else if(op==3)
{
scanf("%d%d",&pos,&k);
do3(pos,k);
}
else if(op==4)
{
scanf("%d%d%d",&l,&r,&k);
printf("%d\n",ask4(l,r,1,k));
}
else
{
scanf("%d%d%d",&l,&r,&k);
printf("%d\n",ask5(l,r,1,k));
}
}
return 0;
}