题目大意:
您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
1. 插入x数
2. 删除x数(若有多个相同的数,因只删除一个)
3. 查询x数的排名(若有多个相同的数,因输出最小的排名)
4. 查询排名为x的数
5. 求x的前驱(前驱定义为小于x,且最大的数)
6. 求x的后继(后继定义为大于x,且最小的数)
题解:
一道平衡树的模板题,我们可以用Treap实现
只需要在每个节点多添加一个域cnt记录该节点所对应的数出现的数量,以及以这颗点为根的子树的大小,就可以查询x数的排名和排名x的数
然后递归实现Treap就可以了
代码:
#include<bits/stdc++.h>
#define INF 0x7fffffff
#define N 100005
using namespace std;
struct Treap{
int l,r;
int val,dat;
int cnt,size;
} a[N];
int tot,root,n;
int New(int val){
a[++tot].val = val;
a[tot].dat = rand();
a[tot].cnt = a[tot].size = 1;
return tot;
}
void Update(int p){
a[p].size = a[a[p].l].size + a[a[p].r].size + a[p].cnt;
}
void Build(){
New(-INF); New(INF);
root = 1;
a[root].r = 2;
Update(root);
}
void zig(int &p){
int q = a[p].l;
a[p].l = a[q].r;
a[q].r = p;
p = q;
Update(a[p].r);
Update(p);
}
void zag(int &p){
int q = a[p].r;
a[p].r = a[q].l;
a[q].l = p;
p = q;
Update(a[p].l);
Update(p);
}
void Insert(int &p,int val){
if (p == 0){
p = New(val);
return;
}
if (val == a[p].val){
a[p].cnt++;
Update(p);
return;
}
if (val < a[p].val){
Insert(a[p].l,val);
if (a[p].dat < a[a[p].l].dat) zig(p);
}
else{
Insert(a[p].r,val);
if (a[p].dat < a[a[p].r].dat) zag(p);
}
Update(p);
}
void Remove(int &p, int val){
if (!p) return;
if (val == a[p].val){
if (a[p].cnt > 1){
a[p].cnt--;
Update(p);
return;
}
if (a[p].l || a[p].r){
if (a[p].r == 0 || a[a[p].l].dat > a[a[p].r].dat){
zig(p); Remove(a[p].r,val);
} else { zag(p),Remove(a[p].l,val); }
Update(p);
} else p = 0;
return;
}
val < a[p].val ? Remove(a[p].l,val) : Remove(a[p].r,val);
Update(p);
}
int GetRankByVal(int p,int val){
if (!p) return 0;
if (val == a[p].val) return a[a[p].l].size + 1;
if (val < a[p].val) return GetRankByVal(a[p].l,val);
return GetRankByVal(a[p].r,val) + a[a[p].l].size + a[p].cnt;
}
int GetValByRank(int p,int rank){
if (!p) return INF;
if (a[a[p].l].size >= rank) return GetValByRank(a[p].l,rank);
if (a[a[p].l].size + a[p].cnt>=rank) return a[p].val;
return GetValByRank(a[p].r,rank - a[a[p].l].size - a[p].cnt);
}
int GetPre(int val){
int ans = 1;
int p = root;
while (p){
if (val == a[p].val){
if (a[p].l > 0){
p = a[p].l;
while (a[p].r > 0) p = a[p].r;
ans = p;
}
break;
}
if (a[p].val < val && a[p].val > a[ans].val) ans = p;
p = val < a[p].val ? a[p].l : a[p].r;
}
return a[ans].val;
}
int GetNext(int val){
int ans = 2;
int p = root;
while (p){
if (val == a[p].val){
if (a[p].r > 0){
p=a[p].r;
while (a[p].l > 0) p = a[p].l;
ans = p;
}
break;
}
if (a[p].val > val && a[p].val < a[ans].val) ans = p;
p = val < a[p].val ? a[p].l : a[p].r;
}
return a[ans].val;
}
int main(){
Build();
scanf("%d",&n);
while (n--){
int s,x;
scanf("%d%d",&s,&x);
if (s == 1) Insert(root,x);
if (s == 2) Remove(root,x);
if (s == 3) printf("%d\n",GetRankByVal(root,x) - 1);
if (s == 4) printf("%d\n",GetValByRank(root,x+1));
if (s == 5) printf("%d\n",GetPre(x));
if (s == 6) printf("%d\n",GetNext(x));
}
}