Problem
Mean
编写一个支持插入、删除元素,查询元素排名,查询相应排名的元素,查询元素前驱与后继的数据结构。
有重复元素存在,且查询前驱后继的元素可能未出现在树中。
Analysis
Treap和Splay都可做。
Splay随便搞一搞就可以了,然而没打过的我只会口胡……
于是打了Treap。
重复元素可以选择在保存在同一节点中,在Struct里新加一个cnt来保存当前节点元素个数。但我写的是插入一个新节点。
难点可能在查询前驱与后继。如果元素没有出现在当前树中,可以先把它加进去,查询完以后再删掉(我觉得这样比较方便)。前驱即kth(o,rank(o,x)-1),后继即kth(o,rank(o,x)+1).但是由于重复元素的存在,既是插入时都在右子树插入重复元素,在调整后仍有可能出现前驱或等于当前元素的情况,所以要不断查询知道不等为止。
类似的,查询rank同样需要不断往前查找,直到出现不相等的元素。
Code
#include<cstdio>
#include<cstdlib>
int n,opt,x;
struct Node{
Node* ch[2];
int r,v,s;
Node(int v):v(v) {ch[0]=ch[1]=NULL;r=rand();s=1;}
int cmp(int x){
if(x==v) return -1;
return x>v;
}
void maintain(){
s=1;
if(ch[0]!=NULL) s+=ch[0]->s;
if(ch[1]!=NULL) s+=ch[1]->s;
}
};
void rotate(Node* &o,int d){
Node* k=o->ch[d^1];o->ch[d^1]=k->ch[d];k->ch[d]=o;
o->maintain();k->maintain();o=k;
}
void insert(Node* &o,int x){
if(o==NULL) o=new Node(x);
else{
int d=x>=o->v;
insert(o->ch[d],x);
if(o->ch[d]->r>o->r) rotate(o,d^1);
}
o->maintain();
}
void remove(Node* &o,int x){
int d=o->cmp(x);
if(d==-1){
if(o->ch[0]==NULL) o=o->ch[1];
else if(o->ch[1]==NULL) o=o->ch[0];
else{
int d2=o->ch[0]->r>o->ch[1]->r;
rotate(o,d2);remove(o->ch[d2],x);
}
}else remove(o->ch[d],x);
if(o!=NULL) o->maintain();
}
int rank(Node* o,int x){
int d=o->cmp(x),s=o->ch[0]==NULL?0:o->ch[0]->s;
if(d==-1) return s+1;
if(d==0) return rank(o->ch[0],x);
return s+1+rank(o->ch[1],x);
}
int kth(Node* o,int k){
int s=o->ch[0]==NULL?0:o->ch[0]->s;
if(s+1==k) return o->v;
if(s>=k) return kth(o->ch[0],k);
return kth(o->ch[1],k-s-1);
}
bool find(Node* o,int x){
while(o!=NULL){
int d=o->cmp(x);
if(d==-1) return 1;
else o=o->ch[d];
}
return 0;
}
int main(){
Node* o=NULL;
scanf("%d",&n);
for(int i=0;i<n;i++){
scanf("%d%d",&opt,&x);
if(opt==1) insert(o,x);
if(opt==2) remove(o,x);
if(opt==3){
int r=rank(o,x);
while(r-1 && kth(o,r-1)==x) r--;
printf("%d\n",r);
}
if(opt==4) printf("%d\n",kth(o,x));
if(opt==5){
bool p=find(o,x);
if(!p) insert(o,x);
int f=x,Rank=rank(o,x),u=0;
while(f==x) f=kth(o,Rank-(++u));
if(!p) remove(o,x);
printf("%d\n",f);
}
if(opt==6){
bool p=find(o,x);
if(!p) insert(o,x);
int f=x,Rank=rank(o,x),u=0;
while(f==x) f=kth(o,Rank+(++u));
if(!p) remove(o,x);
printf("%d\n",f);
}
}
return 0;
}