P1728 普通平衡树
时间: 1000ms / 空间: 131072KiB / Java类名: Main
背景
此为平衡树系列第一道:普通平衡树
描述
您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
1. 插入x数
2. 删除x数(若有多个相同的数,因只删除一个)
3. 查询x数的排名(若有多个相同的数,因输出最小的排名)
4. 查询排名为x的数
5. 求x的前驱(前驱定义为小于x,且最大的数)
6. 求x的后继(后继定义为大于x,且最小的数)
1. 插入x数
2. 删除x数(若有多个相同的数,因只删除一个)
3. 查询x数的排名(若有多个相同的数,因输出最小的排名)
4. 查询排名为x的数
5. 求x的前驱(前驱定义为小于x,且最大的数)
6. 求x的后继(后继定义为大于x,且最小的数)
输入格式
第一行为n,表示操作的个数,下面n行每行有两个数opt和x,opt表示操作的序号(1<=opt<=6)
输出格式
对于操作3,4,5,6每行输出一个数,表示对应答案
测试样例1
输入
8
1 10
1 20
1 30
3 20
4 2
2 10
5 25
6 -1
输出
2
20
20
20
备注
n<=100000 所有数字均在-10^7到10^7内
题解:平衡树的基础操作。
Treap代码:
#include<algorithm>
#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
#include<math.h>
#include<stdlib.h>
#define nn 210000
#define eps 1e-8
#define inff 0x7fffffff
#define lson rt<<1,l,m
#define rson rt<<1|1,m+1,r
#define mod 20071027
using namespace std;
typedef long long LL;
typedef unsigned long long LLU;
struct multi_treap
{
struct node
{
node *ch[2];
int r,v;
int cnt,sum;
int cmp(int x)
{
if(v==x)
return -1;
return x<v?0:1;
}
};
node *root;//treap的根节点
int numof(node* o)//以o为根的树的结点数目
{
if(o==NULL)
return 0;
return o->sum;
}
void update(node* o)//更新以该节点为根的数的个数
{
if(o==NULL)
return ;
o->sum=numof(o->ch[0])+numof(o->ch[1])+o->cnt;
}
void Rotate(node* &o,int d)//d为0左旋,d为1右旋
{
node *k=o->ch[d^1];
o->ch[d^1]=k->ch[d];
k->ch[d]=o;
update(o);
update(k);
o=k;
}
void Insert(node* &o,int x)//插入x
{
if(o==NULL)
{
o=new node();
o->v=x,o->r=rand()*rand();
o->cnt=o->sum=1;
o->ch[0]=o->ch[1]=NULL;
return ;
}
int d=o->cmp(x);
if(d==-1)
{
o->cnt++;//注释掉这两句话就是关闭元素可重复的功能
o->sum++;
return ;
}
Insert(o->ch[d],x);
if(o->ch[d]->r>o->r)
{
Rotate(o,d^1);
}
update(o);
}
void Remove(node* &o,int x)//删除x
{
if(o==NULL)
return ;
int d=o->cmp(x);
if(d==-1)
{
if(o->cnt==1)
{
if(o->ch[0]==NULL)
{
node *k=o;
o=o->ch[1];
delete k;
}
else if(o->ch[1]==NULL)
{
node *k=o;
o=o->ch[0];
delete k;
}
else
{
int d2=o->ch[0]->r>o->ch[1]->r?1:0;
Rotate(o,d2);
Remove(o->ch[d2],x);
}
}
else
o->cnt--;
}
else
Remove(o->ch[d],x);
update(o);
}
int Rank(node *o,int x)//查询x的排名
{
int ans=0;
while(o!=NULL)
{
if(o->v>x)
{
o=o->ch[0];
}
else if(o->v<x)
{
ans+=numof(o->ch[0])+o->cnt;
o=o->ch[1];
}
else
{
ans+=numof(o->ch[0])+1;
break;
}
}
return ans;
}
int Kth(node *o,int x)//查询排名为x的数
{
if(o==NULL)
return -1;
int ix=numof(o->ch[0]);
if(ix>=x)
return Kth(o->ch[0],x);
ix=x-ix-o->cnt;
if(ix<=0)
return o->v;
return Kth(o->ch[1],ix);
}
int Pre(node* o,int x)//查询x的前驱
{
int ans=-1;
while(o!=NULL)
{
if(o->v<x)
{
ans=o->v;
o=o->ch[1];
}
else
o=o->ch[0];
}
return ans;
}
int Suc(node* o,int x)//查询x的后继
{
int ans=-1;
while(o!=NULL)
{
if(o->v>x)
{
ans=o->v;
o=o->ch[0];
}
else
o=o->ch[1];
}
return ans;
}
void De(node* o)//清空
{
if(o==NULL)
return ;
De(o->ch[0]);
De(o->ch[1]);
delete o;
}
}tp;
int main()
{
int n;
int x,y,i;
while(scanf("%d",&n)!=EOF)
{
tp.root=NULL;
for(i=1;i<=n;i++)
{
scanf("%d%d",&x,&y);
if(x==1)
{
tp.Insert(tp.root,y);
}
else if(x==2)
{
tp.Remove(tp.root,y);
}
else if(x==3)
{
printf("%d\n",tp.Rank(tp.root,y));
}
else if(x==4)
{
printf("%d\n",tp.Kth(tp.root,y));
}
else if(x==5)
{
printf("%d\n",tp.Pre(tp.root,y));
}
else
printf("%d\n",tp.Suc(tp.root,y));
}
tp.De(tp.root);
}
return 0;
}
Splay
#include<stdio.h>
#include<iostream>
#include<algorithm>
const int inf=0x3fffffff;
using namespace std;
struct node
{
int val;
int num;
int sum;
node* pre;
node* ch[2];
}*root;
void display(node* o)
{
if(o==NULL)
return ;
display(o->ch[0]);
cout<<o->val<<endl;
display(o->ch[1]);
}
int numof(node* o)
{
if(o==NULL)
return 0;
return o->sum;
}
void update(node* o)
{
if(o==NULL)
return ;
o->sum=o->num+numof(o->ch[0])+numof(o->ch[1]);
}
void Rotate(node* o)
{
node* tem=o->pre;
int d;
if(tem->ch[0]==o) d=0;
else d=1;
tem->ch[d]=o->ch[d^1];
if(o->ch[d^1]!=NULL)
o->ch[d^1]->pre=tem;
if(tem->pre!=NULL)
{
if(tem->pre->ch[0]==tem)
tem->pre->ch[0]=o;
else
tem->pre->ch[1]=o;
}
o->pre=tem->pre;
o->ch[d^1]=tem;
tem->pre=o;
update(tem);
//update(o);伸展完以后再更新,减少常数
}
void Splay(node* o,node* f)
{
node* x;
node* y;
while(o->pre!=f)
{
if(o->pre->pre==f)
Rotate(o);
else
{
x=o->pre;
y=x->pre;
int d1,d2;
if(y->ch[0]==x) d1=0;
else d1=1;
if(x->ch[0]==o) d2=0;
else d2=1;
if(d1==d2)
{
Rotate(x);
Rotate(o);
}
else
{
Rotate(o);
Rotate(o);
}
}
}
update(o);
if(f==NULL)
root=o;
}
void Insert(node* &o,node* pre,int val)
{
if(o==NULL)
{
o=new node;
o->val=val;
o->sum=o->num=1;
o->pre=pre;
o->ch[0]=o->ch[1]=NULL;
Splay(o,NULL);
return ;
}
if(val==o->val)
{
o->num++;
o->sum++;
Splay(o,NULL);//一定要先更新,在伸展
}
else if(val<o->val)
Insert(o->ch[0],o,val);
else
Insert(o->ch[1],o,val);
}
node* Pre(node* o,int val)
{
node* re;
while(o!=NULL)
{
if(o->val<val)
{
re=o;
o=o->ch[1];
}
else
o=o->ch[0];
}
return re;
}
node* Suc(node* o,int val)
{
node* re;
while(o!=NULL)
{
if(o->val>val)
{
re=o;
o=o->ch[0];
}
else
o=o->ch[1];
}
return re;
}
void Remove(node* o,int val)
{
if(o==NULL)
return ;
if(val==o->val)
{
node* x=Pre(root,val);
node* y=Suc(root,val);
Splay(x,NULL);
Splay(y,x);
o->num--;
o->sum--;
if(o->num==0)
{
delete y->ch[0];
y->ch[0]=NULL;
}
Splay(y,NULL);
}
else if(val<o->val)
Remove(o->ch[0],val);
else
Remove(o->ch[1],val);
}
int Rank(node* o,int val)
{
if(o==NULL)
return 0;
if(val==o->val)
{
return numof(o->ch[0]);
}
else if(val<o->val)
{
return Rank(o->ch[0],val);
}
else
return numof(o->ch[0])+o->num+Rank(o->ch[1],val);
}
int Kth(node* o,int k)
{
// if(o==NULL)
// return 0;
if(numof(o->ch[0])>=k)
return Kth(o->ch[0],k);
else if(numof(o->ch[0])+o->num>=k)
return o->val;
else
return Kth(o->ch[1],k-numof(o->ch[0])-o->num);
}
void Clear(node* o)
{
if(o==NULL)
return ;
Clear(o->ch[0]);
Clear(o->ch[1]);
delete o;
}
int main()
{
int i,d,x;
int n;
while(scanf("%d",&n)!=EOF)
{
root=NULL;
Insert(root,NULL,-inf);
Insert(root,NULL,inf);
for(i=1;i<=n;i++)
{
scanf("%d%d",&d,&x);
if(d==1)
{
Insert(root,NULL,x);
}
else if(d==2)
{
Remove(root,x);
}
else if(d==3)
{
printf("%d\n",Rank(root,x));
}
else if(d==4)
{
printf("%d\n",Kth(root,x+1));
}
else if(d==5)
{
printf("%d\n",Pre(root,x)->val);
}
else
{
printf("%d\n",Suc(root,x)->val);
}
}
Clear(root);
}
return 0;
}