题目描述
您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
插入 xx 数
删除 xx 数(若有多个相同的数,因只删除一个)
查询 xx 数的排名(排名定义为比当前数小的数的个数 +1+1 。若有多个相同的数,因输出最小的排名)
查询排名为 xx 的数
求 xx 的前驱(前驱定义为小于 xx ,且最大的数)
求 xx 的后继(后继定义为大于 xx ,且最小的数)
输入输出格式
输入格式:
第一行为 nn ,表示操作的个数,下面 nn 行每行有两个数 optopt 和 xx , optopt 表示操作的序号( 1 \leq opt \leq 6 1≤opt≤6 )
输出格式:
对于操作 3,4,5,63,4,5,6 每行输出一个数,表示对应答案
思路
treap模版
只要了解旋转,一切好说
代码
#include <iostream>
#include <cstdio>
#include <ctime>
#include <cstdlib>
const int maxn=1e5+7;
const int inf=0x3f3f3f3f;
using namespace std;
struct node{
int l,r;
int key,data;
int size;
}t[maxn];
int n,op,x,cnt,root;
void updata(int x)
{
t[x].size=t[t[x].l].size+t[t[x].r].size+1;
}
void rttr(int &x)
{
int y=t[x].l;
t[x].l=t[y].r;
t[y].r=x;
updata(x); updata(y);
x=y;
}
void rttl(int &x)
{
int y=t[x].r;
t[x].r=t[y].l;
t[y].l=x;
updata(x); updata(y);
x=y;
}
void ins(int &x,int k)
{
if (x==0)
{
x=++cnt;
t[cnt].size=1;
t[cnt].data=k;
t[cnt].key=rand();
return;
}
if (k<=t[x].data)
{
ins(t[x].l,k);
if (t[t[x].l].key>t[x].key) rttr(x);
}
else
{
ins(t[x].r,k);
if (t[t[x].r].key>t[x].key) rttl(x);
}
updata(x);
}
int get_rank(int x,int k)
{
if (x==0) return 0;
if (k>t[x].data) return t[t[x].l].size+get_rank(t[x].r,k)+1;
else return get_rank(t[x].l,k);
}
int get_val(int x,int k)
{
if (t[t[x].l].size+1==k) return t[x].data;
if (k<t[t[x].l].size+1) return get_val(t[x].l,k);
else return get_val(t[x].r,k-t[t[x].l].size-1);
}
int get_pre(int k)
{
int x=root;
int ans=-inf;
while (x)
{
if (k>t[x].data)
{
if (t[x].data>=ans) ans=t[x].data;
x=t[x].r;
}
else x=t[x].l;
}
return ans;
}
int get_next(int k)
{
int x=root;
int ans=inf;
while (x)
{
if (k<t[x].data)
{
if (t[x].data<=ans) ans=t[x].data;
x=t[x].l;
}
else x=t[x].r;
}
return ans;
}
void del(int &x,int k)
{
if (x==0) return;
if (t[x].data==k)
{
if ((t[x].l) || (t[x].r))
{
if ((t[x].l) && ((!t[x].r) || (t[t[x].l].key<t[t[x].r].key)))
{
rttr(x); del(t[x].r,k);
}
else
{
rttl(x); del(t[x].l,k);
}
updata(x);
}
else x=0;
return;
}
if (k<t[x].data) del(t[x].l,k);
else del(t[x].r,k);
updata(x);
}
int main()
{
//freopen("data.in","r",stdin);
//freopen("data.out","w",stdout);
scanf("%d",&n);
for (int i=1;i<=n;i++)
{
scanf("%d%d",&op,&x);
if (op==1) ins(root,x);
if (op==2) del(root,x);
if (op==3) printf("%d\n",get_rank(root,x)+1);
if (op==4) printf("%d\n",get_val(root,x));
if (op==5) printf("%d\n",get_pre(x));
if (op==6) printf("%d\n",get_next(x));
}
}