大概是平衡树的模板了。一直是按大蓝皮书写的treap模板,后来拜读别人博客的时候发现可以把相同的数合并到一个节点里,再记录一下个数,这样就不用建一大堆节点了,觉得非常妙。
如果不合并的话,还得记录子树中的最大值和最小值,用来判断子树中是否还有与当前值相同的数(其实不用?好像直接返回极值就可以了)。
#include<iostream>
#include<cstdlib>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
// Treap node for the variant that stores every copy of a value in its own node.
struct node
{
    int val;        // key stored in this node
    int rank;       // random heap priority (insert rotates larger priorities upward)
    int size;       // number of nodes in this subtree
    int maxi;       // largest key in this subtree
    int mini;       // smallest key in this subtree
    node *son[2];   // son[0] = left child, son[1] = right child
    // Index of the child to descend into for key: 1 (right) when this key is smaller.
    bool cmp(int key)
    {
        return val < key;
    }
};
node Tnull;             // shared sentinel that stands in for every empty child
node *null = &Tnull;    // empty links point at the sentinel
node *root = null;      // the treap starts empty
// Allocates a fresh leaf holding val_, stores it through the reference o,
// and returns it. The return statement is required: this function is
// declared to return node*, and falling off its end is undefined behaviour.
node* newnode(node *&o,int val_)
{
    o=new node;
    o->son[0]=o->son[1]=null;   // empty children point at the shared sentinel
    o->maxi=o->mini=o->val=val_;
    o->rank=rand();             // random heap priority
    o->size=1;
    return o;
}
// Recomputes o's subtree size and key range (maxi/mini) from its own key
// and its children.
void maintain(node *&o)
{
    int total=1;
    int hi=o->val;
    int lo=o->val;
    for(int k=0;k<2;++k)
    {
        node *ch=o->son[k];
        if(ch==null)continue;   // skip the sentinel: its zeroed maxi/mini would corrupt the range
        total+=ch->size;
        if(ch->maxi>hi)hi=ch->maxi;
        if(ch->mini<lo)lo=ch->mini;
    }
    o->size=total;
    o->maxi=hi;
    o->mini=lo;
}
// Rotates child o->son[d] up into o's place; o becomes that child's
// son[d^1]. The reference o is updated to the new subtree root.
void rotate(node *&o,bool d)
{
    node *up=o->son[d];
    o->son[d]=up->son[!d];   // up's inner subtree is handed to o
    up->son[!d]=o;
    maintain(o);             // o is now the lower node, so refresh it first
    maintain(up);
    o=up;
}
// Inserts one copy of val into the treap rooted at o (every copy gets its
// own node; keys equal to o->val descend left) and returns the new subtree
// root. The function is declared to return node*, so it must return on
// every path; it also no longer depends on newnode()'s return value.
node* insert(node *&o,int val)
{
    if(o==null)
        newnode(o,val);                 // newnode stores the new leaf through o
    else
    {
        int d=o->cmp(val);
        insert(o->son[d],val);
        if(o->son[d]->rank > o->rank)   // keep larger priorities above smaller ones
            rotate(o,d);
    }
    maintain(o);
    return o;
}
// Removes one node holding val from the treap rooted at o. If val is not
// present, the tree is left unchanged; previously the search walked into the
// sentinel and dereferenced its nullptr children.
void erase(node *&o,int val)
{
    if(o==null)return;   // val not found in this subtree
    if(o->val==val)
    {
        if(o->son[0]!=null&&o->son[1]!=null)
        {
            // Rotate the higher-priority child up, then keep chasing the node downward.
            bool d=o->son[0]->rank < o->son[1]->rank;
            rotate(o,d);
            erase(o->son[d^1],val);
        }
        else
        {
            node *t=o;
            o=(o->son[0]!=null)?o->son[0]:o->son[1];   // splice in the only child (or the sentinel)
            delete t;
        }
    }
    else erase(o->son[o->cmp(val)],val);
    if(o!=null)maintain(o);
}
// Returns 1 + the number of keys in o's subtree that are strictly less than val.
// val does not need to be present: reaching an empty subtree contributes
// nothing. Previously such a query dereferenced the sentinel's nullptr children.
int checkrank(node *o,int val)
{
    if(o==null)return 1;
    if(o->val==val)
    {
        // Equal keys may also live in the left subtree; the leftmost copy decides the rank.
        if(o->son[0]!=null&&o->son[0]->maxi==val)
            return checkrank(o->son[0],val);
        return o->son[0]->size+1;   // sentinel size is 0
    }
    if(o->val<val)
        return checkrank(o->son[1],val)+o->son[0]->size+1;
    return checkrank(o->son[0],val);
}
// Returns the num-th smallest key (1-based) in o's subtree; num must be in range.
int checknum(node *o,int num)
{
    for(;;)
    {
        int left=(o->son[0]==null)?0:o->son[0]->size;
        if(num==left+1)return o->val;
        if(num<=left)
        {
            o=o->son[0];
        }
        else
        {
            num-=left+1;   // skip the left subtree and o itself
            o=o->son[1];
        }
    }
}
// Returns the largest key in o's subtree that is strictly less than val,
// or -(1<<30) when no such key exists. Previously that case dereferenced
// the sentinel's nullptr children.
int getpre(node *o,int val)
{
    // Only reachable along a pure-left path, i.e. when no smaller key exists.
    if(o==null)return -(1<<30);
    if(o->val>=val)
        return getpre(o->son[0],val);
    // o is a candidate; descend right only if that subtree holds something smaller than val.
    if(o->son[1]!=null&&o->son[1]->mini<val)
        return getpre(o->son[1],val);
    return o->val;
}
// Returns the smallest key in o's subtree that is strictly greater than val,
// or 1<<30 when no such key exists. Previously that case dereferenced the
// sentinel's nullptr children.
int getsuc(node *o,int val)
{
    // Only reachable along a pure-right path, i.e. when no larger key exists.
    if(o==null)return 1<<30;
    if(o->val<=val)
        return getsuc(o->son[1],val);
    // o is a candidate; descend left only if that subtree holds something larger than val.
    if(o->son[0]!=null&&o->son[0]->maxi>val)
        return getsuc(o->son[0],val);
    return o->val;
}
// Reads n, then n lines "op x" and dispatches them: 1 insert, 2 erase,
// 3 checkrank, 4 checknum, 5 getpre, any other op getsuc.
int main()
{
    int n;
    scanf("%d",&n);
    for(int done=0;done<n;++done)
    {
        int op,val;
        scanf("%d%d",&op,&val);
        switch(op)
        {
        case 1: insert(root,val); break;
        case 2: erase(root,val); break;
        case 3: printf("%d\n",checkrank(root,val)); break;
        case 4: printf("%d\n",checknum(root,val)); break;
        case 5: printf("%d\n",getpre(root,val)); break;
        default: printf("%d\n",getsuc(root,val)); break;
        }
    }
    return 0;
}
合并的写法:
#include<iostream>
#include<cstdlib>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
// Treap node for the merged variant: equal keys share one node, counted in num.
struct node
{
    int val;        // key
    int rank;       // random heap priority
    int num;        // how many copies of val this node represents
    int size;       // total number of copies in this subtree
    node *son[2];   // son[0] = left child, son[1] = right child
    // Index of the child to descend into for key: 1 (right) when key is larger.
    bool cmp(int key)
    {
        return key > val;
    }
};
node tnull;             // shared sentinel that stands in for every empty child
node *null = &tnull;
node *root = null;      // the treap starts empty
// Allocates a leaf holding one copy of val and stores it through the reference o.
void newnode(node *&o,int val)
{
    node *fresh=new node;
    fresh->val=val;
    fresh->num=fresh->size=1;
    fresh->son[0]=null;
    fresh->son[1]=null;
    fresh->rank=rand();   // random heap priority
    o=fresh;
}
// Recomputes o->size as o's own multiplicity plus the sizes of its real children.
void maintain(node *&o)
{
    int total=o->num;
    if(o->son[0]!=null)total+=o->son[0]->size;
    if(o->son[1]!=null)total+=o->son[1]->size;
    o->size=total;
}
// Lifts child o->son[d] into o's position; o becomes that child's son[d^1].
// The reference o is updated to the new subtree root.
void rotate(node *&o,bool d)
{
    node *child=o->son[d];
    node *inner=child->son[!d];   // subtree that changes parent
    child->son[!d]=o;
    o->son[d]=inner;
    maintain(o);                  // o is now the lower node, so refresh it first
    maintain(child);
    o=child;
}
// Inserts one copy of val into the treap rooted at o; an existing key only
// has its multiplicity increased.
void insert(node *&o,int val)
{
    if(o==null)
    {
        newnode(o,val);
        return;
    }
    if(o->val!=val)
    {
        bool side=o->cmp(val);
        insert(o->son[side],val);
        if(o->son[side]->rank > o->rank)   // restore heap order on priorities
            rotate(o,side);
    }
    else
    {
        ++o->num;
        ++o->size;
    }
    maintain(o);
}
// Removes one copy of val from the treap rooted at o. If val is not present,
// the tree is left unchanged; previously the search dereferenced the
// sentinel's nullptr children.
void erase(node *&o,int val)
{
    if(o==null)return;   // val not found in this subtree
    if(o->val==val)
    {
        if(o->num>=2)
        {
            // Just drop one copy; the recursive callers re-maintain the ancestors.
            o->num--;
            o->size--;
            return;
        }
        if(o->son[1]!=null&&o->son[0]!=null)
        {
            // Rotate the higher-priority child up, then keep chasing the node downward.
            bool d=o->son[1]->rank > o->son[0]->rank;
            rotate(o,d);
            erase(o->son[d^1],val);
        }
        else
        {
            node *t=o;
            o=(o->son[0]==null)?o->son[1]:o->son[0];   // splice in the only child (or the sentinel)
            delete t;
        }
    }
    else
        erase(o->son[o->cmp(val)],val);
    if(o!=null)maintain(o);
}
// Returns 1 + the total number of copies of keys strictly less than val.
// val does not need to be present: an empty subtree contributes nothing.
// Previously such a query dereferenced the sentinel's nullptr children.
int checkrank(node *o,int val)
{
    if(o==null)return 1;
    if(val==o->val)return o->son[0]->size+1;   // sentinel size is 0
    if(o->cmp(val))
        return checkrank(o->son[1],val)+o->son[0]->size+o->num;
    return checkrank(o->son[0],val);
}
// Returns the key at 1-based position rank (counting multiplicity) in o's
// subtree; rank must be in range.
int checknum(node *o,int rank)
{
    while(true)
    {
        int left=o->son[0]->size;   // sentinel size is 0
        if(rank<=left)
        {
            o=o->son[0];
            continue;
        }
        if(rank<=left+o->num)return o->val;
        rank-=left+o->num;          // skip the left subtree and all copies at o
        o=o->son[1];
    }
}
// Returns the largest key strictly less than val in o's subtree, or
// -(1<<30) when none exists. An empty tree (o == null) is now handled;
// previously it dereferenced the sentinel's nullptr children.
int getpre(node *o,int val)
{
    if(o==null)return -(1<<30);
    if(o->val>=val)return getpre(o->son[0],val);
    if(o->son[1]==null)return o->val;
    return std::max(o->val,getpre(o->son[1],val));
}
// Returns the smallest key strictly greater than val in o's subtree, or
// 1<<30 when none exists. An empty tree (o == null) is now handled;
// previously it dereferenced the sentinel's nullptr children.
int getsuc(node *o,int val)
{
    if(o==null)return 1<<30;
    if(o->val<=val)return getsuc(o->son[1],val);
    if(o->son[0]==null)return o->val;
    return std::min(o->val,getsuc(o->son[0],val));
}
// Reads n, then n lines "op x" and dispatches them: 1 insert, 2 erase,
// 3 checkrank, 4 checknum, 5 getpre, any other op getsuc.
int main()
{
    int n;
    scanf("%d",&n);
    for(int done=0;done<n;++done)
    {
        int op,val;
        scanf("%d%d",&op,&val);
        switch(op)
        {
        case 1: insert(root,val); break;
        case 2: erase(root,val); break;
        case 3: printf("%d\n",checkrank(root,val)); break;
        case 4: printf("%d\n",checknum(root,val)); break;
        case 5: printf("%d\n",getpre(root,val)); break;
        default: printf("%d\n",getsuc(root,val)); break;
        }
    }
    return 0;
}
闲得无聊又码了一遍splay
#include<iostream>
#include<cstring>
#include<cstdlib>
#include<cstdio>
#include<algorithm>
using namespace std;
// Splay-tree node; equal keys share one node and are counted in cnt.
struct node
{
    int val,size,cnt;   // key, total copies in this subtree, copies of val
    node *fa,*son[2];   // parent and children (NULL when absent)
    node(int val_=0,node* fa_=NULL):val(val_),size(1),cnt(1),fa(fa_)
    {
        son[0]=NULL;
        son[1]=NULL;
    }
    // True when this node is its parent's right child; requires fa != NULL.
    bool dir()
    {
        return fa->son[1]==this;
    }
    // True when key belongs in the right subtree.
    bool cmp(int key)
    {
        return val<key;
    }
};
node *root;   // the tree starts empty
// Recomputes o->size from o's own count and its children, and re-points
// each existing child's parent link at o.
void maintain(node *o)
{
    int total=o->cnt;
    for(int k=0;k<2;++k)
    {
        node *ch=o->son[k];
        if(!ch)continue;
        ch->fa=o;
        total+=ch->size;
    }
    o->size=total;
}
// Rotates o above its parent p, preserving in-order order, and returns o.
// Requires o->fa != NULL. The return statement is required: the function is
// declared to return node*, and falling off its end is undefined behaviour.
node* rotate(node *o)
{
    int d=o->dir();                    // side of p that o hangs on
    node *p=o->fa;
    p->son[d]=o->son[d^1];             // o's inner subtree moves under p
    o->son[d^1]=p;
    o->fa=p->fa;
    if(p->fa)p->fa->son[p->dir()]=o;   // p->fa still points at p here, so dir() is valid
    maintain(p),maintain(o);           // also re-links p->fa and the moved subtree's fa
    return o;
}
// Rotates o upward until its parent is fin (fin == NULL means all the way to
// the root, which also updates the global root) and returns o. The return
// statement is required: the function is declared to return node*, and
// falling off its end is undefined behaviour.
node* splay(node *o,node *fin)
{
    while(o->fa!=fin)
    {
        node *p=o->fa;
        if(p->fa==fin)rotate(o);                          // single rotation
        else if(p->dir()==o->dir())rotate(p),rotate(o);   // same side: lift parent first
        else rotate(o),rotate(o);                         // opposite sides: lift o twice
    }
    if(!fin)root=o;
    return o;
}
// Inserts one copy of val below the slot o (whose parent is fa) and splays
// the resulting node to the root.
void insert(node *&o,node *fa,int val)
{
    if(o==NULL)
    {
        o=new node(val,fa);   // hang a fresh leaf in the empty slot
        splay(o,NULL);
        return;
    }
    if(o->val==val)
    {
        ++o->cnt;
        ++o->size;
        splay(o,NULL);
        return;
    }
    insert(o->son[o->cmp(val)],o,val);
    // maintain(o) is not needed here: splaying re-maintains every node on the path.
}
// Returns the node holding val in o's subtree, or NULL if there is none.
node *find(node *o,int val)
{
    while(o!=NULL && o->val!=val)
        o=o->son[o->val<val];   // right when the current key is smaller
    return o;
}
void erase(int val)
{
node *o=find(root,val);
splay(o,NULL);
o->size--;o->cnt--;
if(o->cnt)return;
if(!o->son[0]&&!o->son[1])root=NULL;
else if(o->son[0]&&o->son[1])
{
node *t=o->son[0];
t->fa=NULL;
while(t->son[1])t=t->son[1];
splay(t,NULL);
t->son[1]=o->son[1];
root=t;
maintain(t);
delete o;
}
else if(o->son[0])
{
o->son[0]->fa=NULL;
root=o->son[0];
delete o;
}
else
{
o->son[1]->fa=NULL;
root=o->son[1];
delete o;
}
}
// Returns 1 + the number of stored copies of keys strictly less than val.
// An empty subtree yields 1, so val need not be present and may exceed every
// key. Previously that case recursed into a NULL right child and crashed, and
// control could reach the end of this non-void function.
int rank(node *o,int val)
{
    if(!o)return 1;
    int left=o->son[0]?o->son[0]->size:0;
    if(o->val==val)return left+1;
    if(o->val>val)return rank(o->son[0],val);
    return left+o->cnt+rank(o->son[1],val);
}
// Returns the key at 1-based position r (counting multiplicity) in o's
// subtree; r must be in range.
int getnum(node *o,int r)
{
    while(true)
    {
        int left=o->son[0]?o->son[0]->size:0;
        if(r<=left)
        {
            o=o->son[0];
            continue;
        }
        r-=left;                 // position relative to o's own copies
        if(r<=o->cnt)return o->val;
        r-=o->cnt;
        o=o->son[1];
    }
}
// Returns the largest key strictly less than val in o's subtree, or
// -(1<<30) if none exists. res used to be uninitialised, so that case
// returned an indeterminate value.
int pre(node *o,int val)
{
    int res=-(1<<30);
    while(o)
    {
        if(o->val>=val)o=o->son[0];
        else res=o->val,o=o->son[1];   // o is a candidate; look for a larger one
    }
    return res;
}
// Returns the smallest key strictly greater than val in o's subtree, or
// 1<<30 if none exists. res used to be uninitialised, so that case
// returned an indeterminate value.
int suc(node *o,int val)
{
    int res=1<<30;
    while(o)
    {
        if(o->val<=val)o=o->son[1];
        else res=o->val,o=o->son[0];   // o is a candidate; look for a smaller one
    }
    return res;
}
// Reads q, then q lines "op x" and dispatches them: 1 insert, 2 erase,
// 3 rank, 4 getnum, 5 pre, any other op suc.
int main()
{
    int q;
    scanf("%d",&q);
    for(int done=0;done<q;++done)
    {
        int op,tmp;
        scanf("%d%d",&op,&tmp);
        switch(op)
        {
        case 1: insert(root,NULL,tmp); break;
        case 2: erase(tmp); break;
        case 3: printf("%d\n",rank(root,tmp)); break;
        case 4: printf("%d\n",getnum(root,tmp)); break;
        case 5: printf("%d\n",pre(root,tmp)); break;
        default: printf("%d\n",suc(root,tmp)); break;
        }
    }
    return 0;
}