平衡树——Treap
\qquad
这篇文章主要介绍用
T
r
e
a
p
Treap
Treap实现的平衡树其它的我不会
Treap原理
Treap = BST + Heap
∙
\bullet
∙先来介绍
B
S
T
BST
BST是什么
\qquad
B
i
n
a
r
y
S
e
r
a
h
T
r
e
e
Binary\quad Serah\quad Tree
BinarySerahTree,即二叉搜索树。其定义为左子树都小于根节点的值,右子树都大于根节点的值。
∙
B
S
T
\bullet BST
∙BST的中序遍历结果一定是从小到大的序列。
∙
\bullet
∙默认的,
B
S
T
BST
BST中结点的值互不相同,如果存在多个相同的值,可以在结点上开一个新变量
c
n
t
cnt
cnt来记录每个值出现的次数
B S T BST BST的操作
1.
1.
1.插入:根据待插入值进行递归插入即可
2.
2.
2.删除:叶节点可以直接删除,如果不是叶节点,则转化为叶节点删除
3.
3.
3.前驱和后继:因为
B
S
T
BST
BST的中序遍历是有序的,所以一个结点的前驱即其中序遍历的前驱,后继即其中序遍历的后继
4.
4.
4.找最值:沿着树遍历即可
5.
5.
5.求某个值的排名
6.
6.
6.求排名第
k
k
k的数是哪个
7.
7.
7.求比
x
x
x小的最值,
x
x
x不一定在树中出现
t
i
p
s
:
tips:
tips:因为最大值无前驱,最小值无后继,所以我们现在树中先加入一个极大值和极小值
H e a p Heap Heap堆的部分
堆的详解见另一篇博客浅谈堆及其应用
这里主要说堆在
T
r
e
a
p
Treap
Treap中的应用
为了维护
B
S
T
BST
BST深度在
log
n
\log n
logn级,可以证明随机地向
B
S
T
BST
BST中插入结点插入节点可以保持其深度的稳定性(我不会证明),这里使用堆来维护
B
S
T
BST
BST深度的稳定,下面定义一个
T
r
e
a
p
Treap
Treap(以大根堆为例)
struct Node{
int l,r; //左右孩子编号
int cnt; //该节点重复key值的个数
int key; //该节点BST中的key值
int val; //该节点堆中的值
int size; //从该点到叶节点的节点个数
}tr[MAXN];
∙
T
r
e
a
p
\bullet Treap
∙Treap同时是一个
B
S
T
BST
BST,也是一个堆
∙
\bullet
∙如果树中所有结点的
k
e
y
,
v
a
l
key,val
key,val值均不相同,则这棵树唯一,任意结点的
k
e
y
key
key值大于左节点,小于右节点,任意结点的
v
a
l
val
val值小于左右节点
下面介绍 T r e a p Treap Treap中的重要操作左旋,右旋
需要注意为了维护结点的
s
i
z
e
size
size,这里需要在每次旋转过后重新计算当前节点
s
i
z
e
size
size,
s
i
z
e
size
size的计算方法基本和线段树相同
维护
s
i
z
e
size
size
void pushup(int k){
tr[k].size = tr[tr[k].l].size + tr[tr[k].r].size + tr[k].cnt;
}
右旋
inline void right(int &p){
int q = tr[p].l;
tr[p].l = tr[q].r;
tr[q].r = p;
p = q;
pushup(tr[p].r);
pushup(p);
}
左旋
inline void left(int &p){
int q = tr[p].r;
tr[p].r = tr[q].l;
tr[q].l = p;
p = q;
pushup(tr[p].l);
pushup(p);
}
对于
T
r
e
a
p
Treap
Treap的操作和
B
S
T
BST
BST基本相同,主要要注意插入与删除操作
∙
\bullet
∙插入:按照
k
e
y
key
key值将目标点插入到树中,再根据其
v
a
l
val
val值进行左右旋转来达到让树高尽量平衡的目的
插入代码
void insert(int &p,int key){
if(!p) p = NEW(key);
else if(tr[p].key == key) tr[p].cnt ++;
else if(tr[p].key > key){
insert(tr[p].l,key);
if(tr[tr[p].l].val > tr[p].val) right(p);
}
else{
insert(tr[p].r,key);
if(tr[tr[p].r].val > tr[p].val) left(p);
}
pushup(p);
}
∙ \bullet ∙删除:将目标节点旋转至叶节点再删除
void remove(int &p,int key){
if(!p) return;
if(tr[p].key == key){ //删除当前节点
if(tr[p].cnt > 1) tr[p].cnt --; //有多个直接cnt--
else if(tr[p].l || tr[p].r){
if((!tr[p].r) || (tr[p].l && tr[tr[p].l].val > tr[tr[p].r].val)){ //满足左旋删除条件
right(p);
remove(tr[p].r,key);
}
else{ //否则右旋删除
left(p);
remove(tr[p].l,key);
}
} else p = 0;
} else if(tr[p].key > key) remove(tr[p].l,key);
else remove(tr[p].r,key);
pushup(p);
}
完整代码
#include<bits\stdc++.h>
using namespace std;
const int MAXN = 1e6 + 10,INF = 1e8;
int idx = 0,root;
struct Node{
int l,r; //左右孩子编号
int cnt; //该节点重复key值的个数
int key; //该节点BST中的key值
int val; //该节点堆中的值
int size; //从该点到叶节点的节点个数
}tr[MAXN];
inline int read(){
int n = 0,l = 1;
char c = getchar();
while(c < '0' || c > '9'){
if(c == '-') l = -1;
c = getchar();
}
while(c >= '0' && c <= '9'){
n = (n << 1) + (n << 3) + (c & 15);
c = getchar();
}
return n * l;
}
inline void pushup(int p){
tr[p].size = tr[tr[p].l].size + tr[tr[p].r].size + tr[p].cnt;
}
inline int NEW(int key){
tr[++ idx].key = key;
tr[idx].val = rand();
tr[idx].cnt = tr[idx].size = 1;
return idx;
}
inline void right(int &p){
int q = tr[p].l;
tr[p].l = tr[q].r;
tr[q].r = p;
p = q;
pushup(tr[p].r);
pushup(p);
}
inline void left(int &p){
int q = tr[p].r;
tr[p].r = tr[q].l;
tr[q].l = p;
p = q;
pushup(tr[p].l);
pushup(p);
}
//建树,提前加入一个极大值和极小值
inline void build(){
NEW(-INF);
NEW(INF);
root = 1;
tr[1].r = 2;
pushup(root);
if(tr[1].val < tr[2].val) left(root);
}
void insert(int &p,int key){
if(!p) p = NEW(key);
else if(tr[p].key == key) tr[p].cnt ++;
else if(tr[p].key > key){
insert(tr[p].l,key);
if(tr[tr[p].l].val > tr[p].val) right(p);
}
else{
insert(tr[p].r,key);
if(tr[tr[p].r].val > tr[p].val) left(p);
}
pushup(p);
}
void remove(int &p,int key){
if(!p) return;
if(tr[p].key == key){ //删除当前节点
if(tr[p].cnt > 1) tr[p].cnt --; //有多个直接cnt--
else if(tr[p].l || tr[p].r){
if((!tr[p].r) || (tr[p].l && tr[tr[p].l].val > tr[tr[p].r].val)){ //满足左旋删除条件
right(p);
remove(tr[p].r,key);
}
else{ //否则右旋删除
left(p);
remove(tr[p].l,key);
}
} else p = 0;
} else if(tr[p].key > key) remove(tr[p].l,key);
else remove(tr[p].r,key);
pushup(p);
}
//求值为key的排名
int get_key(int p,int key){
if(!p) return 0;
if(tr[p].key == key) return tr[tr[p].l].size + 1;
if(tr[p].key > key) return get_key(tr[p].l,key);
return tr[tr[p].l].size + tr[p].cnt + get_key(tr[p].r,key);
}
//求排名为rank的数
int get_rank(int p,int rank){
if(!p) return INF;
if(tr[tr[p].l].size >= rank) return get_rank(tr[p].l,rank);
if(tr[tr[p].l].size + tr[p].cnt >= rank) return tr[p].key;
return get_rank(tr[p].r,rank - tr[tr[p].l].size - tr[p].cnt);
}
//求前驱
int get_pre(int p,int key){
if(!p) return -INF;
if(tr[p].key >= key) return get_pre(tr[p].l,key);
return max(tr[p].key,get_pre(tr[p].r,key));
}
//求后继
int get_next(int p,int key){
if(!p) return INF;
if(tr[p].key <= key) return get_next(tr[p].r,key);
return min(tr[p].key,get_next(tr[p].l,key));
}
int main(){
build();
int n = read();
while(n --){
int opt = read(),x = read();
if(opt == 1) insert(root,x);
if(opt == 2) remove(root,x);
if(opt == 3) printf("%d\n",get_key(root,x) - 1);
if(opt == 4) printf("%d\n",get_rank(root,x + 1));
if(opt == 5) printf("%d\n",get_pre(root,x));
if(opt == 6) printf("%d\n",get_next(root,x));
}
return 0;
}