平衡树-Treap
2021年8月6日
什么是平衡树?
平衡树是指任意节点左右子树高度差都小于等于1的二叉树。
平衡树干什么?
平衡树对序列的排序,寻找元素的位置有很方便的操作
算法原理
建树
Treap通过随机赋予节点一个权值进行建树,由此大概率避免了建树成链的情况。
【1】、初始化树
初始化时,需要首先新建key值为正负无穷的两个节点,作为树的两端。随后依据随机赋予的权值进行平衡操作。
void built() {
get_node(-INF), get_node(INF);
root = 1; tre[1].r = 2;
pushup(root);
if (tre[1].val < tre[2].val)lez(root);
}
【2】、新增节点
新增节点时,将其权值赋予一个随机数,随后初始化当前节点数目以及子节点数目。
int get_node(int key){
tre[++idx].key=key;
tre[idx].val=rand();
tre[idx].cnt=tre[idx].size=1;
return idx;
}
【3】、更新点值
对于每次操作之后,树的形态发生变化,其节点内容也要发生变化,因此需要更新其子节点数目。
void pushup(int p) {//更新值
tre[p].size = tre[tre[p].l].size + tre[tre[p].r].size + tre[p].cnt;
}
【4】、加入新值
加入新节点,首先对比其key值,依据大小进行左右遍历.
因加入新节点时,可能有大权值在下的情况,因此需要对节点进行旋转维护。
void insert(int& p, int key) {
if (!p)p = get_node(key);
else if (key == tre[p].key)tre[p].cnt++;
else if (key > tre[p].key) {
insert(tre[p].r, key);
if (tre[tre[p].r].val > tre[p].val)lez(p);
}
else {
insert(tre[p].l, key);
if (tre[tre[p].l].val > tre[p].val)riz(p);
}
pushup(p);
}
【5】、删除值
依然根据key值进行查找,查找到的节点若有多个相同值,则自减1,否则删除节点。
删除节点时,若无任何子树,则直接删除,否则,将其旋转至最底层,然后删除。
旋转规则:
当其右子树为空或左子树权值更大时,右旋此节点(保证权值大的节点在上面),否则左旋此节点。因其上约束条件约定一定有子树,因此若不满足右旋条件,一定是左子树为空或右子树权值更大。
void remove(int& p, int key) {
if (!p)return;
if (tre[p].key == key) {
if (tre[p].cnt > 1)tre[p].cnt--;
else if (tre[p].l || tre[p].r)
if (!tre[p].r || tre[tre[p].l].val > tre[tre[p].r].val) {riz(p); remove(tre[p].r, key);}
else { lez(p); remove(tre[p].l, key); }
else p = 0;
}
else if (tre[p].key > key) remove(tre[p].l, key);
else remove(tre[p].r, key);
pushup(p);
}
【6】、左旋,右旋
左旋:
首先将节点的右子节点更新为右子节点的左子节点,再将右子树的左子节点指向原根节点,随后将根节点更新为右子节点,更新新的左子树与根节点的size值,左旋完成。
右旋:
首先将节点的左子节点更新为左子节点的右子节点,再将左子树的右子节点指向原根节点,随后将根节点更新为左子节点,更新新的右子树与根节点的size值,左旋完成。
左右旋的操作完全相反。
void lez(int& p) {//左旋
int q = tre[p].r;
tre[p].r = tre[q].l; tre[q].l = p; p = q;
pushup(tre[p].l); pushup(p);
}
void riz(int& p) {//右旋
int q = tre[p].l;
tre[p].l = tre[q].r; tre[q].r = p; p = q;
pushup(tre[p].r); pushup(p);
}
为何如此旋转?
以右旋为例,如图所示:
若要将左子节点当成根节点,必须将其右子树拿开之后才能使原来的根节点成为新的右子树,此时相当于原根节点少了一个左子树,原左子节点的右子节点没有了父节点。因左子树的所有值都比根节点小,因此左子节点的右子节点一定比根节点小,因此可以将原根节点当成原左子节点的右子节点的父节点。
以上操作之后,就可旋转为:
由此不会改变树的中序遍历顺序。
查询
【1】、根据值寻找排名
注:此排名代表有多少比key小的数(加上其本身1个数)
若查询到相等的节点,则返回其左子树的值加上它本身。
若值比当前节点key值小,需向左查询,此时因为当前节点key更大,因此不记录排名,直接返回左子树的值即可。
若值比当前值更大,就需要向右子树查找,但是当前节点以及其左子树是要记录到排名中,因此返回左子树的大小,节点本身数量以及右子树查询到的值。
int getrk(int p, int key) {
if (!p)return 0;
if (tre[p].key == key)return tre[tre[p].l].size + 1;
if (tre[p].key > key)return getrk(tre[p].l, key);
else return tre[tre[p].l].size + tre[p].cnt + getrk(tre[p].r, key);
}
【2】、根据排名找值
若查询到的节点的左子树大小要比所求排名大,则需向左搜索。
若查询到的节点的左子树大小更小但是加上其节点本身大于等于所需排名,则结果就是当前节点的值。
否则,向右子树查找,每次所需排名需减去当前点及其左子树的大小。
int getkr(int p, int rank) {
if (!p)return INF;
if (tre[tre[p].l].size >= rank)return getkr(tre[p].l, rank);
if (tre[tre[p].l].size + tre[p].cnt >= rank)return tre[p].key;
return getkr(tre[p].r, rank - tre[tre[p].l].size - tre[p].cnt);
}
【3】、查询序列中严格比某值小的最大值
首先找到某值的左子节点,之后向右子树遍历,知道找到最大值。
int minxk(int p, int key) {
if (!p)return -INF;
if (tre[p].key >= key)return minxk(tre[p].l, key);
else return max(tre[p].key, minxk(tre[p].r, key));
}
【4】、查询序列中严格比某值大的最小值
与【3】相反
int maxnk(int p, int key) {
if (!p)return INF;
if (tre[p].key <= key)return maxnk(tre[p].r, key);
else return min(tre[p].key, maxnk(tre[p].l, key));
}
模板例题
ACwing \253. 普通平衡树
您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
- 插入数值 x。
- 删除数值 x(若有多个相同的数,应只删除一个)。
- 查询数值 x 的排名(若有多个相同的数,应输出最小的排名)。
- 查询排名为 x 的数值。
- 求数值 x 的前驱(前驱定义为小于 x 的最大的数)。
- 求数值 x 的后继(后继定义为大于 x 的最小的数)。
注意: 数据保证查询的结果一定存在。
输入格式
第一行为 n,表示操作的个数。
接下来 n 行每行有两个数 opt 和 xx,opt 表示操作的序号(1≤opt≤6)。
输出格式
对于操作 3,4,5,6 每行输出一个数,表示对应答案。
数据范围
1≤n≤100000,所有数均在 −1e7到 1e7 内。
输入样例:
8
1 10
1 20
1 30
3 20
4 2
2 10
5 25
6 -1
输出样例:
2
20
20
20
代码
#include<iostream>
#include<algorithm>
#include<stdio.h>
#include<string>
#include<string.h>
using namespace std;
const int N = 100005, INF = 0x3f3f3f3f;
int n;
struct P {
int l, r;
int key;
int cnt, size, val;
}tre[N];
int root, idx;
void pushup(int p) {//更新值
tre[p].size = tre[tre[p].l].size + tre[tre[p].r].size + tre[p].cnt;
}
int get_node(int key) {//建立节点
tre[++idx].key = key;
tre[idx].val = rand();
tre[idx].cnt = tre[idx].size = 1;
return idx;
}
void lez(int& p) {//左旋
int q = tre[p].r;
tre[p].r = tre[q].l; tre[q].l = p; p = q;
pushup(tre[p].l); pushup(p);
}
void riz(int& p) {//右旋
int q = tre[p].l;
tre[p].l = tre[q].r; tre[q].r = p; p = q;
pushup(tre[p].r); pushup(p);
}
void built() {//建树
get_node(-INF), get_node(INF);
root = 1; tre[1].r = 2;
pushup(root);
if (tre[1].val < tre[2].val)lez(root);
}
void insert(int& p, int key) {//加点
if (!p)p = get_node(key);
else if (key == tre[p].key)tre[p].cnt++;
else if (key > tre[p].key) {
insert(tre[p].r, key);
if (tre[tre[p].r].val > tre[p].val)lez(p);
}
else {
insert(tre[p].l, key);
if (tre[tre[p].l].val > tre[p].val)riz(p);
}
pushup(p);
}
void remove(int& p, int key) {//删点
if (!p)return;
if (tre[p].key == key) {
if (tre[p].cnt > 1)tre[p].cnt--;
else if (tre[p].l || tre[p].r) {
if (!tre[p].r || tre[tre[p].l].val > tre[tre[p].r].val) {
riz(p); remove(tre[p].r, key);
}
else {
lez(p); remove(tre[p].l, key);
}
}
else p = 0;
}
else if (tre[p].key > key) {
remove(tre[p].l, key);
}
else remove(tre[p].r, key);
pushup(p);
}
int getrk(int p, int key) {//找排名
if (!p)return 0;
if (tre[p].key == key)return tre[tre[p].l].size + 1;
if (tre[p].key > key)return getrk(tre[p].l, key);
else return tre[tre[p].l].size + tre[p].cnt + getrk(tre[p].r, key);
}
int getkr(int p, int rank) {//找值
if (!p)return INF;
if (tre[tre[p].l].size >= rank)return getkr(tre[p].l, rank);
if (tre[tre[p].l].size + tre[p].cnt >= rank)return tre[p].key;
return getkr(tre[p].r, rank - tre[tre[p].l].size - tre[p].cnt);
}
int minxk(int p, int key) {//求前驱
if (!p)return -INF;
if (tre[p].key >= key)return minxk(tre[p].l, key);
else return max(tre[p].key, minxk(tre[p].r, key));
}
int maxnk(int p, int key) {//求后继
if (!p)return INF;
if (tre[p].key <= key)return maxnk(tre[p].r, key);
else return min(tre[p].key, maxnk(tre[p].l, key));
}
int main() {
built();
cin >> n;
while (n--) {
int op, x;
cin >> op >> x;
if (op == 1)insert(root, x);
else if (op == 2)remove(root, x);
else if (op == 3)cout << getrk(root, x) - 1 << endl;
else if (op == 4)cout << getkr(root, x + 1) << endl;
else if (op == 5)cout << minxk(root, x) << endl;
else cout << maxnk(root, x) << endl;
}
return 0;
}