题目链接
题意:
您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
插入数值 x。
删除数值 x(若有多个相同的数,应只删除一个)。
查询数值 x 的排名(若有多个相同的数,应输出最小的排名)。
查询排名为 x 的数值。
求数值 x 的前驱(前驱定义为小于 x 的最大的数)。
求数值 x 的后继(后继定义为大于 x 的最小的数)。
注意: 数据保证查询的结果一定存在。
分析:
现在先来介绍一下平衡树是什么吧,其实平衡树就是一棵二叉搜索树,对于学过数据结构的童鞋对这个二叉搜索树的概念应该并不陌生,他就是一棵二叉树,根节点的权值比左右孩子节点的权值都大,所以说这棵树要是完全二叉树的话,找一个节点的复杂度就是O(logn),但是随着插入顺序的不同,二叉搜索树的形态也会不同,这棵树的形态在最极端的情况下就是一个长链,那么现在查询一个节点的时间复杂度就成了O(N)了,那么怎么优化呢,就引入了平衡树的概念,现在来介绍两个概念,左旋和右旋,下面看一下这张图:
左边这棵树变成右边这棵树就是左旋,反之则是右旋,意思就是根节点向哪里走,如果根节点向左走就是左旋,向右走就是右旋,那么这样就在不改变二叉树中序遍历的前提下使得二叉树更趋近于完全二叉树,在这个题目中的具体体现就是这个val,他是一个随机数,所以说随机数的话一般不会太拉跨,因为总不能随机一个递增的序列或者是随机一个递减的序列吧hh,那么咱们现在来看这个题,insert的话就是插入一个节点,如果已经有了这个节点的话就加1,remove的话就是删除这个值代表的节点,如果这个值所在节点的数量是1的话,删完就没有了,那么现在咱们就需要左旋或者右旋把他旋到最底部的叶子节点,那么删了它的话就不会牵一发而动全身了,剩下的就都是查找操作了,都挺简单的,我先说一下结构体里面的元素代表的具体含义,key就是这个节点的值,val就是这个节点的权重,cnt就是这个节点的值的数量,size就是这个节点以及他的子树中的节点的值的数量的和,下面我就不过多的赘述了,请各位读者结合函数的自身含义进行理解,还要注意有两个哨兵,正无穷和负无穷,这样可以防止数组越界和减少代码量,下面请看代码:
#include<iostream>
#include<cstdio>
#include<cstdlib>
using namespace std;
const int N = 100010,INF = 2147483647;
struct node{
int l,r,cnt,size,val,key;
};
int idx,root;
node tr[N];
int get_node(int key){
tr[++idx].key = key;
tr[idx].val = rand();
tr[idx].cnt = tr[idx].size = 1;
tr[idx].l = tr[idx].r = 0;
return idx;
}
void pu(int p){
tr[p].size = tr[tr[p].l].size + tr[tr[p].r].size + tr[p].cnt;
}
void zig(int &p){
int q = tr[p].l;
tr[p].l = tr[q].r;tr[q].r = p;p = q;
pu(tr[p].r);pu(p);
}
void zag(int &p){
int q = tr[p].r;
tr[p].r = tr[q].l;tr[q].l = p;p = q;
pu(tr[p].l);pu(p);
}
void build(){
root = get_node(INF);
int d = get_node(-INF);
tr[root].l = d;
if(tr[1].val < tr[2].val) zag(root);
pu(root);
}
void insert(int &p,int x){
if(!p) p = get_node(x);
else if(tr[p].key == x){
tr[p].cnt++;
}
else if(tr[p].key > x){
insert(tr[p].l,x);
if(tr[tr[p].l].val > tr[p].val) zig(p);
}
else if(tr[p].key < x){
insert(tr[p].r,x);
if(tr[tr[p].r].val > tr[p].val) zag(p);
}
pu(p);
}
void remove(int &p,int x){
if(!p) return;
if(tr[p].key == x){
if(tr[p].cnt > 1) tr[p].cnt--;
else if(tr[p].l || tr[p].r){
if(!tr[p].r || tr[tr[p].l].val > tr[tr[p].r].key){
zig(p);
remove(tr[p].r,x);
}
else{
zag(p);
remove(tr[p].l,x);
}
}
else p = 0;
}
else if(tr[p].key > x) remove(tr[p].l,x);
else remove(tr[p].r,x);
pu(p);
}
int get_rank_by_val(int p,int key){//找到
if(!p) return 0;
if(tr[p].key == key) return tr[tr[p].l].size + 1;
else if(tr[p].key > key) return get_rank_by_val(tr[p].l,key);
else return tr[tr[p].l].size + tr[p].cnt + get_rank_by_val(tr[p].r,key);
}
int get_val_by_rank(int p,int rank){
if(!p) return 0;
if(tr[tr[p].l].size >= rank) return get_val_by_rank(tr[p].l,rank);
else if(tr[tr[p].l].size + tr[p].cnt >= rank) return tr[p].key;
else return get_val_by_rank(tr[p].r,rank-tr[p].cnt-tr[tr[p].l].size);
}
int get_prev(int p,int x){
if(!p) return -INF;
if(tr[p].key >= x) return get_prev(tr[p].l,x);
else return max(tr[p].key,get_prev(tr[p].r,x));
}
int get_bacv(int p,int x){
if(!p) return INF;
if(tr[p].key <= x) return get_bacv(tr[p].r,x);
else return min(tr[p].key,get_bacv(tr[p].l,x));
}
int main(){
int n;
scanf("%d",&n);
build();
for(int i=0;i<n;i++){
int op,x;
scanf("%d%d",&op,&x);
if(op == 1) insert(root,x);
else if(op == 2) remove(root,x);
else if(op == 3) printf("%d\n",get_rank_by_val(root,x)-1);
else if(op == 4) printf("%d\n",get_val_by_rank(root,x+1));
else if(op == 5) printf("%d\n",get_prev(root,x));
else if(op == 6) printf("%d\n",get_bacv(root,x));
}
return 0;
}