概念
T r e a p Treap Treap 叫树堆,是一种平衡二叉树,它为每一个节点加上了一个随机数,使其满足堆的性质(默认大根堆),而节点的值又满足二叉搜索树的性质, T r e a p Treap Treap 能够实现随机平衡,期望的时间复杂度是 O ( l o g N ) O(logN) O(logN),相对于其他类型的平衡二叉树:
T r e a p Treap Treap 优点:
- 实现的逻辑比较简单,写法简单
- 支持分裂合并,但要涉及可持久化,这种情况下伸展树更常用
- 比 A V L AVL AVL 树常数小
- 可以充当 R a n k − T r e e Rank-Tree Rank−Tree,这是 S T L − s e t STL-set STL−set 无法实现的,常常用在线段树套平衡树中
T r e a p Treap Treap 缺点:
- 相对于红黑之类的平衡树,速度比较慢
- 具有随机性,不稳定
写法
本文介绍
T
r
e
a
p
Treap
Treap 的几种基本操作:
1. pushup操作
当前节点时,用左右子树进行更新
//从下至上更新节点值
inline void pushup(int k) //对以k为根节点进行sum更新 ,注意默认tr[k].sum=tr[k].cnt=0
{
tr[k].sum=tr[tr[k].l].sum+tr[tr[k].r].sum+tr[k].cnt;
}
2. 旋转操作
以右旋为例子:
//右旋转zig 发现具有连锁
void zig(int& p) //注意一定是引用,引用的是一个数组值
{
int tp=tr[p].l;
tr[p].l=tr[tp].r;
tr[tp].r=p;
p=tp;
pushup(tr[p].r);
pushup(p);
}
3.插入函数
首先不断查找,若查找到直接递增次数;若没查找到,则在叶子节点上插入,尾递归返回时,每一步都要判断堆的性质,保持大根堆性质,且要 p u s h u p pushup pushup
//插入函数
void insert(int& p,int val) //p为根节点的树中,插入值为val的数,必须要引用
{
if(!p) p=getnode(val);
else if(tr[p].val==val) ++tr[p].cnt;
else if(tr[p].val<val) //若当前节点的值小于目标节点,则到右子树查找
{
insert(tr[p].r,val);//因为插入以后,直接相连的右节点可能发生改变,因此要满足堆的性质
if(tr[p].ord<tr[tr[p].r].ord) zag(p);
}
else
{
insert(tr[p].l,val);
if(tr[p].ord<tr[tr[p].l].ord) zig(p);
}
pushup(p); //注意要更新
}
4.删除函数
首先也是查找到待删除的节点,若只有一个或没得子孙节点,直接删除并子承父业;
若均有左右子孙节点,且次数等于一,则要把该节点旋转到以上的情况再删除
那么?是旋到左子树还是右子树呢???
那么就要根据堆的性质来旋转,既要把ord大的值旋转上来,因为是大根堆
//删除函数
void dele(int& p,int val) //找到了要删除的节点,且次数等于1,则要把该节点旋转到叶子节点再删除
{
if(!p) return; //若没找到,尾递归返回
else if(tr[p].val==val)
{
if(tr[p].cnt>1) --tr[p].cnt;
else
{
if(!tr[p].l||!tr[p].r) p=tr[p].l+tr[p].r; //若非满,则子承父业
else if(tr[tr[p].l].ord>tr[tr[p].r].ord)//要根据子节点的ord值,ord值大的翻上来
{
zig(p);
dele(tr[p].r,val);
}
else
{
zag(p);
dele(tr[p].l,val);
}
}
}
else if(tr[p].val<val) dele(tr[p].r,val);
else dele(tr[p].l,val);
pushup(p);
}
5.找前驱与后缀
类似 L C A LCA LCA 的倍增逼近法,由于有哨兵,所以一定能查找到.
int getpre(int val) //找到严格小于val的最大值 ,一定能查找到,因为有哨兵
{
int p=root,res;
while(p)
{
if(tr[p].val<val) res=tr[p].val,p=tr[p].r;
else p=tr[p].l;
}
return res;
}
int getnxt(int val) //找到严格大于val的最小值
{
int p=root,res;
while(p)
{
if(tr[p].val>val) res=tr[p].val,p=tr[p].l;
else p=tr[p].r;
}
return res;
}
若不是严格的前驱后缀,这样:
int getpre(int val)
{
int p=root,res;
while(p)
{
if(tr[p].val<val)
{
res=tr[p].val;
p=tr[p].r;
}
else if(tr[p].val>val) p=tr[p].l;
else
{
res=tr[p].val;
break;
}
}
return res;
}
int getnxt(int val)
{
int p=root,res;
while(p)
{
if(tr[p].val>val)
{
res=tr[p].val;
p=tr[p].l;
}
else if(tr[p].val<val) p=tr[p].r;
else
{
res=tr[p].val;
break;
}
}
return res;
}
6.通过数值找排名,排名找数值
int val_rank(int p,int val) //通过数值找排名
{
if(!p) return 0;
else if(tr[p].val==val) return tr[tr[p].l].sum+1;
else if(tr[p].val>val) return val_rank(tr[p].l,val);
return tr[tr[p].l].sum+tr[p].cnt+val_rank(tr[p].r,val);
}
int rank_val(int p,int rnk) //通过排名找数值
{
if(!p) return INF;
else if(tr[tr[p].l].sum>=rnk) return rank_val(tr[p].l,rnk);
else if(tr[tr[p].l].sum+tr[p].cnt>=rnk) return tr[p].val;
return rank_val(tr[p].r,rnk-tr[tr[p].l].sum-tr[p].cnt);
}
7.初始化
void build()
{
getnode(INF),getnode(-INF);
root=1,tr[root].l=2;
if(tr[root].ord<tr[tr[root].l].ord) zig(root);
}
完整代码
#include<cstdio>
#include<cstring>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<vector>
#include<string>
#include<set>
#include<map>
#include<unordered_map>
#include<queue>
#define me(x,y) memset(x,y,sizeof x)
#define rep(i,x,y) for(i=x;i<=y;++i)
#define repf(i,x,y) for(i=x;i>=y;--i)
#define lowbit(x) -x&x
#define inf 0x3f3f3f3f
#define INF 0x7fffffff
#define f first
#define s second
using namespace std;
typedef long long ll;
typedef long double ld;
typedef pair<int,int> PII;
inline int read()
{
int x=0,f=1;char ch=getchar();
while (ch<'0'||ch>'9'){if (ch=='-') f=-1;ch=getchar();}
while (ch>='0'&&ch<='9'){x=x*10+ch-48;ch=getchar();}
return x*f;
}
struct node
{
int l,r; //l为左节点,r为右节点,0则为空节点
int val,ord; //val为实际权值,odi为堆的随机值
int cnt,sum; //cnt为该值的次数,sum为以该子树的总cnt
};
const int N= 1e5+10;
int n,idx,root; //idx为当前的节点数量,root为根节点的编号
node tr[N];
inline int getnode(int val) //构建新节点,返回节点的编号,因为总是在叶子节点插入新节点
{
tr[++idx]={0,0,val,rand(),1,1};
return idx;
}
//从下至上更新节点值
inline void pushup(int k) //对以k为根节点进行sum更新 ,注意默认tr[k].sum=tr[k].cnt=0
{
tr[k].sum=tr[tr[k].l].sum+tr[tr[k].r].sum+tr[k].cnt;
}
//右旋转zig
void zig(int& p) //注意一定是引用,引用的是一个数组值
{
int tp=tr[p].l; //首先把左节点保存起来
tr[p].l=tr[tp].r; //链接
tr[tp].r=p;
p=tp;
pushup(tr[p].r);
pushup(p);
}
//左旋转
void zag(int& p)
{
int tp=tr[p].r; //首先把右节点保存起来
tr[p].r=tr[tp].l; //链接
tr[tp].l=p;
p=tp;
pushup(tr[p].l);
pushup(p);
}
//插入函数
void insert(int& p,int val) //p为根节点的树中,插入值为val的数,必须要引用
{
if(!p) p=getnode(val);
else if(tr[p].val==val) ++tr[p].cnt;
else if(tr[p].val<val) //若当前节点的值小于目标节点,则到右子树查找
{
insert(tr[p].r,val); //因为插入以后,直接相连的右节点可能发生改变,因此要满足堆的性质
if(tr[p].ord<tr[tr[p].r].ord) zag(p);
}
else
{
insert(tr[p].l,val);
if(tr[p].ord<tr[tr[p].l].ord) zig(p);
}
pushup(p); //注意要更新
}
//删除函数
void dele(int& p,int val) //找到了要删除的节点,且次数等于1,则要把该节点旋转到叶子节点再删除
{
if(!p) return; //若没找到,尾递归返回
else if(tr[p].val==val)
{
if(tr[p].cnt>1) --tr[p].cnt;
else
{
if(!tr[p].l||!tr[p].r) p=tr[p].l+tr[p].r; //若非满,则子承父业
else if(tr[tr[p].l].ord>tr[tr[p].r].ord) //要根据子节点的ord值,ord值大的翻上来
{
zig(p);
dele(tr[p].r,val);
}
else
{
zag(p);
dele(tr[p].l,val);
}
}
}
else if(tr[p].val<val) dele(tr[p].r,val);
else dele(tr[p].l,val);
pushup(p);
}
int getpre(int val) //找到严格小于val的最大值 ,一定能查找到,因为有哨兵
{
int p=root,res;
while(p)
{
if(tr[p].val<val) res=tr[p].val,p=tr[p].r;
else p=tr[p].l;
}
return res;
}
int getnxt(int val) //找到严格大于val的最小值
{
int p=root,res;
while(p)
{
if(tr[p].val>val) res=tr[p].val,p=tr[p].l;
else p=tr[p].r;
}
return res;
}
int val_rank(int p,int val) //通过数值找排名
{
if(!p) return 0;
else if(tr[p].val==val) return tr[tr[p].l].sum+1;
else if(tr[p].val>val) return val_rank(tr[p].l,val);
return tr[tr[p].l].sum+tr[p].cnt+val_rank(tr[p].r,val);
}
int rank_val(int p,int rnk) //通过排名找数值
{
if(!p) return INF;
else if(tr[tr[p].l].sum>=rnk) return rank_val(tr[p].l,rnk);
else if(tr[tr[p].l].sum+tr[p].cnt>=rnk) return tr[p].val;
return rank_val(tr[p].r,rnk-tr[tr[p].l].sum-tr[p].cnt);
}
void build()
{
getnode(INF),getnode(-INF);
root=1,tr[root].l=2;
if(tr[root].ord<tr[tr[root].l].ord) zig(root);
}
int main()
{
int i,j,tag;
n=read(),build();
while(n--)
{
tag=read(),i=read();
switch(tag)
{
case 1:
insert(root,i);
break;
case 2:
dele(root,i);
break;
case 3:
printf("%d\n",val_rank(root,i)-1); //切记
break;
case 4:
++i; //切记
printf("%d\n",rank_val(root,i));
break;
case 5:
printf("%d\n",getpre(i));
break;
case 6:
printf("%d\n",getnxt(i));
break;
}
}
return 0;
}