题目描述
您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
插入xx数
删除xx数(若有多个相同的数,因只删除一个)
查询xx数的排名(排名定义为比当前数小的数的个数+1+1。若有多个相同的数,因输出最小的排名)
查询排名为xx的数
求xx的前驱(前驱定义为小于xx,且最大的数)
求xx的后继(后继定义为大于xx,且最小的数)
输入输出格式
输入格式:
第一行为nn,表示操作的个数,下面nn行每行有两个数optopt和xx,optopt表示操作的序号( 1 \leq opt \leq 6 1≤opt≤6 )
输出格式:
对于操作3,4,5,63,4,5,6每行输出一个数,表示对应答案
输入输出样例
输入样例#1:
10
1 106465
4 1
1 317721
1 460929
1 644985
1 84185
1 89851
6 81968
1 492737
5 493598
输出样例#1:
106465
84185
492737
平衡树的模板题,写的是Treap(因为书上写的就是Treap)。水(抄)过了一道模板题后,感觉还是挺懵逼的……
#include<cstdio>
#include<cstdlib>
const int N=1e5+10;
const int INF=0x7fffffff;
int n,tot,rt;
struct Treap{
int l,r,val,dat,cnt,sz;
}t[N];
int New(int val)
{
t[++tot].val=val;
t[tot].dat=rand();
t[tot].cnt=t[tot].sz=1;
return tot;
}
void update(int p)
{
t[p].sz=t[t[p].l].sz+t[t[p].r].sz+t[p].cnt;
}
void build()
{
New(-INF),New(INF);
rt=1;t[1].r=2;
update(rt);
}
void zig(int &p)
{
int q=t[p].l;
t[p].l=t[q].r;t[q].r=p;p=q;
update(t[p].r);update(p);
}
void zag(int &p)
{
int q=t[p].r;
t[p].r=t[q].l;t[q].l=p;p=q;
update(t[p].l);update(p);
}
void Insert(int &p,int val)
{
if(!p){p=New(val);return;}
if(val==t[p].val){t[p].cnt++;update(p);return;}
if(val<t[p].val)
{
Insert(t[p].l,val);
if(t[p].dat<t[t[p].l].dat)zig(p);
}
else
{
Insert(t[p].r,val);
if(t[p].dat>t[t[p].r].dat)zag(p);
}
update(p);
}
void Remove(int &p,int val)
{
if(!p)return;
if(val==t[p].val)
{
if(t[p].cnt>1)
{
t[p].cnt--;update(p);return;
}
if(t[p].l||t[p].r)
{
if(!t[p].r||t[t[p].l].dat>t[t[p].r].dat)zig(p),Remove(t[p].r,val);
else zag(p),Remove(t[p].l,val);
update(p);
}
else p=0;return;
}
if(val<t[p].val)Remove(t[p].l,val);
else Remove(t[p].r,val);
update(p);
}
int getr(int p,int val)
{
if(!p)return 0;
if(val==t[p].val)return t[t[p].l].sz+1;
if(val<t[p].val)return getr(t[p].l,val);
return getr(t[p].r,val)+t[t[p].l].sz+t[p].cnt;
}
int getv(int p,int r)
{
if(!p)return INF;
if(t[t[p].l].sz>=r)return getv(t[p].l,r);
if(t[t[p].l].sz+t[p].cnt>=r)return t[p].val;
return getv(t[p].r,r-t[t[p].l].sz-t[p].cnt);
}
int getpre(int val)
{
int p=rt,ans=1;
while(p)
{
if(val==t[p].val)
{
if(t[p].l>0)
{
p=t[p].l;
while(t[p].r>0)p=t[p].r;
ans=p;
}
break;
}
if(t[p].val<val&&t[p].val>t[ans].val)ans=p;
p=val<t[p].val?t[p].l:t[p].r;
}
return t[ans].val;
}
int getnx(int val)
{
int p=rt,ans=2;
while(p)
{
if(val==t[p].val)
{
if(t[p].r>0)
{
p=t[p].r;
while(t[p].l>0)p=t[p].l;
ans=p;
}
break;
}
if(t[p].val>val&&t[p].val<t[ans].val)ans=p;
p=val<t[p].val?t[p].l:t[p].r;
}
return t[ans].val;
}
int main()
{
//freopen("in.txt","r",stdin);
build();
int opt,x;
scanf("%d",&n);
while(n--)
{
scanf("%d%d",&opt,&x);
switch(opt)
{
case 1:Insert(rt,x);break;
case 2:Remove(rt,x);break;
case 3:printf("%d\n",getr(rt,x)-1);break;
case 4:printf("%d\n",getv(rt,x+1));break;
case 5:printf("%d\n",getpre(x));break;
case 6:printf("%d\n",getnx(x));break;
}
}
return 0;
}
总结
我自己都不太明白……