题目比较简洁:
您需要写一种数据结构(可参考题目标题),来维护一个有序数列,其中需要提供以下操作:
1.查询k在区间内的排名
2.查询区间内排名为k的值
3.修改某一位值上的数值
4.查询k在区间内的前驱(前驱定义为小于x,且最大的数)
5.查询k在区间内的后继(后继定义为大于x,且最小的数)
又有区间,又有排名,所以明显是树套树了,反正我写了区间线段树套平衡树。
相当于在每个线段树上的点都开一棵splay,表示这个区间内的数的排名。
因为splay是动态开点的,所以时间O(nlog^2n)空间nlog^2n,都可以接受。
对于一操作,大概就是把对应区间在线段树上的点的k的排名求出来,然后加起来。
那么二操作二分下答案,用一函数判断合法性就可以了。
修改比较好做,删了再增加。
前驱即把各个区间的答案(如果有的话)取最大值。
后继同上,但取最小值。
代码略长,做好心理准备。
code:
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<queue>
#include<algorithm>
using namespace std;
const int inf=(1<<28);
struct trnode{
int n,d,c,fa,son[2],lc,rc;
}tr[2000010];int root[100010],trlen=0;
queue <int> q;
int n,m,a[50010];
int read()
{
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
void update(int x)
{
int lc=tr[x].son[0],rc=tr[x].son[1];
tr[x].c=tr[lc].c+tr[rc].c+tr[x].n;
}
int add(int d,int fa)
{
int x=q.front();q.pop();
tr[x].d=d;tr[x].n=tr[x].c=1;tr[x].fa=fa;
tr[x].son[0]=tr[x].son[1]=0;
if(fa!=0) tr[fa].son[d<tr[fa].d?0:1]=x;
return x;
}
void rotate(int x)
{
int y=tr[x].fa,z=tr[y].fa,w,R,r;
w=tr[y].son[0]==x?1:0;
R=y;r=tr[x].son[w];
tr[R].son[1-w]=r;
if(r!=0) tr[r].fa=R;
R=z;r=x;
tr[R].son[tr[z].son[0]==y?0:1]=r;
if(r!=0) tr[r].fa=R;
R=x;r=y;
tr[R].son[w]=r;
if(r!=0) tr[r].fa=R;
update(y);update(x);
}
void splay(int x,int fa,int rt)
{
while(tr[x].fa!=fa)
{
int y=tr[x].fa,z=tr[y].fa;
if(z==fa) rotate(x);
else
if((tr[z].son[0]==y)==(tr[y].son[0]==x)) rotate(y),rotate(x);
else rotate(x),rotate(x);
}
if(fa==0) root[rt]=x;
}
int findid(int d,int rt)
{
int x=root[rt];
while(tr[x].d!=d)
{
int lc=tr[x].son[0],rc=tr[x].son[1];
if(d<tr[x].d)
if(lc!=0) x=lc;
else break;
else
if(rc!=0) x=rc;
else break;
}
return x;
}
void ins(int d,int rt)
{
if(root[rt]==0){root[rt]=add(d,0);return;}
int x=findid(d,rt);
if(tr[x].d==d) tr[x].n++;
else add(d,x);
update(x);splay(x,0,rt);
}
void del(int d,int rt)
{
int x=findid(d,rt);splay(x,0,rt);
if(tr[x].n>1){tr[x].n--;update(x);return;}
if(tr[x].son[0]==0&&tr[x].son[1]==0){root[rt]=0;q.push(x);return;}
if(tr[x].son[0]!=0&&tr[x].son[1]==0){root[rt]=tr[x].son[0];tr[root[rt]].fa=0;q.push(x);return;}
if(tr[x].son[0]==0&&tr[x].son[1]!=0){root[rt]=tr[x].son[1];tr[root[rt]].fa=0;q.push(x);return;}
int p=tr[x].son[0];
while(tr[p].son[1]!=0) p=tr[p].son[1];
splay(p,x,rt);
root[rt]=p;tr[p].fa=0;q.push(x);
int R=p,r=tr[x].son[1];
tr[R].son[1]=r;
if(r!=0) tr[r].fa=R;
update(p);
}
int findqianqu(int d,int rt)
{
int x=findid(d,rt);splay(x,0,rt);
if(tr[x].d>=d)
if(tr[x].son[0]!=0)
{
x=tr[x].son[0];
while(tr[x].son[1]!=0) x=tr[x].son[1];
}
splay(x,0,rt);
return x;
}
int findhouji(int d,int rt)
{
int x=findid(d,rt);splay(x,0,rt);
if(tr[x].d<=d)
if(tr[x].son[1]!=0)
{
x=tr[x].son[1];
while(tr[x].son[0]!=0) x=tr[x].son[0];
}
splay(x,0,rt);
return x;
}
int splayrank(int d,int rt)
{
int x=findqianqu(d,rt);
if(tr[x].d>=d) return 0;
return tr[x].n+tr[tr[x].son[0]].c;
}
int bt(int l,int r)
{
int x=++trlen;
if(l!=r)
{
int mid=(l+r)/2;
tr[x].lc=bt(l,mid);
tr[x].rc=bt(mid+1,r);
}
return x;
}
void change(int x,int l,int r,int k,int c,int tmp,bool first)
{
if(!first) del(tmp,x);
ins(c,x);
if(l==r) return;
int mid=(l+r)/2,lc=tr[x].lc,rc=tr[x].rc;
if(k<=mid) change(lc,l,mid,k,c,tmp,first);
else change(rc,mid+1,r,k,c,tmp,first);
return;
}
int findrank(int x,int l,int r,int fl,int fr,int d)
{
if(l==fl&&r==fr) return splayrank(d,x);
int mid=(l+r)/2,lc=tr[x].lc,rc=tr[x].rc;
if(fr<=mid) return findrank(lc,l,mid,fl,fr,d);
if(fl>mid) return findrank(rc,mid+1,r,fl,fr,d);
return findrank(lc,l,mid,fl,mid,d)+findrank(rc,mid+1,r,mid+1,fr,d);
}
int findnum(int fl,int fr,int k)
{
int l=0,r=100000010,ans;
while(l<=r)
{
int mid=(l+r)/2;
int tmp=findrank(1,1,n,fl,fr,mid)+1;
if(tmp<=k) ans=mid,l=mid+1;
else r=mid-1;
}
return ans;
}
int solve1(int x,int l,int r,int fl,int fr,int d)
{
if(l==fl&&r==fr) return tr[findqianqu(d,x)].d;
int mid=(l+r)/2,lc=tr[x].lc,rc=tr[x].rc;
if(fr<=mid) return solve1(lc,l,mid,fl,fr,d);
if(fl>mid) return solve1(rc,mid+1,r,fl,fr,d);
int lcc=solve1(lc,l,mid,fl,mid,d),rcc=solve1(rc,mid+1,r,mid+1,fr,d),ans=-inf;
if(lcc<d) ans=max(ans,lcc);if(rcc<d) ans=max(ans,rcc);
return ans;
}
int solve2(int x,int l,int r,int fl,int fr,int d)
{
if(l==fl&&r==fr) return tr[findhouji(d,x)].d;
int mid=(l+r)/2,lc=tr[x].lc,rc=tr[x].rc;
if(fr<=mid) return solve2(lc,l,mid,fl,fr,d);
if(fl>mid) return solve2(rc,mid+1,r,fl,fr,d);
int lcc=solve2(lc,l,mid,fl,mid,d),rcc=solve2(rc,mid+1,r,mid+1,fr,d),ans=inf;
if(lcc>d) ans=min(ans,lcc);if(rcc>d) ans=min(ans,rcc);
return ans;
}
int main()
{
n=read();m=read();
while(!q.empty()) q.pop();
for(int i=1;i<=20*n;i++) q.push(i);
bt(1,n);
for(int i=1;i<=n;i++)
{
a[i]=read();
change(1,1,n,i,a[i],0,true);
}
while(m--)
{
int tmp;tmp=read();int l,r,k,c;
if(tmp!=3) l=read(),r=read(),k=read();
else k=read(),c=read();
if(tmp==1) printf("%d\n",findrank(1,1,n,l,r,k)+1);
if(tmp==2) printf("%d\n",findnum(l,r,k));
if(tmp==3){change(1,1,n,k,c,a[k],false);a[k]=c;}
if(tmp==4){printf("%d\n",solve1(1,1,n,l,r,k));}
if(tmp==5){printf("%d\n",solve2(1,1,n,l,r,k));}
}
}