Treap=Tree+Heap
仔细观察Treap这个名字,你会发现他是Tree+Heap的合体
所以他的中文名叫树堆,正如其字面意思
对于权值而言,它是二叉查找树
对于优先级,它是堆
Treap的每个节点需要维护下列值
int rt;
int ch[maxn][2];//左右孩子指针,0为左孩子,1,为右孩子
int val[maxn],rnd[maxn],size[maxn],cnt[maxn],tot;
//val:权值,cnt:该节点保存相同权值的个数,size:该节点的子树的节点数,rnd:优先级
Treap–旋转
平衡树最重要的操作–旋转
直接按代码的步骤画图模拟一遍就很好理解了
这里以右旋(将p向右旋转)为例,一开始的树长这样
现在我们用指针K指向他的左子树
令p的左子树变成K的右子树
令K的右子树变成p
注意这里只在指针K上完成了旋转
我们还要赋值给p,即令p=k
左旋则正好相反,可以尝试自己画一下加深理解
//p:带旋转结点,d==0:左旋, d==1:右旋,旋转见图解
void rotate(int &p,int d)//记得引用
{
int k=ch[p][d^1];
ch[p][d^1]=ch[k][d];
ch[k][d]=p;
update(p); update(k);
p=k;
}
Treap–插入
既然Treap是二叉查找树
那么插入自然按照二叉查找树的性质
即每个结点的左孩子权值一定小于该节点,而右孩子反之
只要按照二叉查找树的插入方式递归寻找带插入位置就好
然而,Treap还是个堆啊
所以插入后还要判断优先级
之前我们初始化优先级rnd为随机rand()
从而保证了整个程序平均的运行效率
void ins(int &p,int x)//记得引用
{
if(!p)//找到带插入节点
{
p=++tot; val[p]=x;
rnd[p]=rand(); size[p]=cnt[p]=1;
return;
}
if(val[p]==x){ ++size[p]; ++cnt[p]; return;}//若已有权值相同的结点,直接更新
int d=x<val[p] ?0:1;//查询应该往哪边插入 ,若x比该节点小,则cmp返回0,反之返回1
ins(ch[p][d],x);
if( rnd[ch[p][d]]<rnd[p]) rotate(p,d^1);//维护堆的性质
update(p);
}
Treap–删除
寻找待删除结点可以依照二叉查找树的性质
找到待删除节点后有几种情况
1.若结点保存了多个相同的值,则直接更新
s
i
z
e
−
1
,
c
n
t
−
1
size-1,cnt-1
size−1,cnt−1即可
2.若结点只有一棵子树,就以该子树代替该节点
3.若结点两棵子树都不为空,则先把优先级较高的子树旋转到根,然后递归在另一子树中继续寻找待删除结点
void del(int &p,int x)//记得引用
{
if(!p) return;
if(val[p]==x)
{
if(cnt[p]>1){ --size[p]; --cnt[p]; return;}//若该节点保存了多个相同值,直接更新即可
else
{
if(!ch[p][0])p=ch[p][1];//如果这个结点只有一棵子树,就以该子树代替该节点
else if(!ch[p][1])p=ch[p][0];
else//两棵子树都不为空
{
int dd=rnd[ch[p][0]] < rnd[ch[p][1]] ?1:0;
rotate(p,dd); del(ch[p][dd],x);
//先把优先级较高的子树旋转到根,然后递归在另一子树中删除p
}
}
}
else if(x<val[p]) del(ch[p][0],x);
else del(ch[p][1],x);
if(p) update(p);//删除后一定要更新信息,更新前判断p不为NULL
}
Treap–查询排名
查询数x的排名
int rank(int p,int x)//查询函数都不加引用
{
int ss=size[ch[p][0]];
if(x<val[p]) return rank(ch[p][0],x);
else if(x==val[p]) return ss+1;//找到待查询结点,则其排名为左子树节点数+1
else return ss+cnt[p]+rank(ch[p][1],x);
//待查询权值在右子树内
//则该节点及其左子树必定小于待查询结点
//所以加上该节点左孩子的size+该点cnt,递归右孩子
}
Treap–查询第K小
查询数x的排名
int kth(int p,int k)
{
int ss=size[ch[p][0]];
if(k<=ss) return kth(ch[p][0],k);//若待查询排名小于左孩子数,则应继续在左子树递归寻找
else if(k<=ss+cnt[p]) return val[p];//如果排名小于左孩子数加该节点cnt,则该节点为待查询结点
else return kth(ch[p][1],k-ss-cnt[p]);//如果排名还要大,那就要往右子树找
}
Treap–前驱后继
求x的前驱(前驱定义为小于x,且最大的数)
求x的后继(后继定义为大于x,且最小的数)
int pre(int p,int x)
{
if(!p) return -inf;
if(x<=val[p]) return pre(ch[p][0],x);
else return max(val[p],pre(ch[p][1],x));
}
int nxt(int p,int x)
{
if(!p) return inf;
if(x>=val[p]) return nxt(ch[p][1],x);
else return min(val[p],nxt(ch[p][0],x));
}
一般题目中会出现的Treap操作都包含在了上面
下面给出洛谷 P3369 【模板】普通平衡树的完整代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
using namespace std;
int read()
{
int f=1,x=0;
char ss=getchar();
while(ss<'0'||ss>'9'){if(ss=='-')f=-1;ss=getchar();}
while(ss>='0'&&ss<='9'){x=x*10+ss-'0';ss=getchar();}
return f*x;
}
const int inf=1e9+7;
const int maxn=500010;
int n,rt,tot;
int rnd[maxn],val[maxn],ch[maxn][2];
int size[maxn],cnt[maxn];
void update(int p){ size[p]=size[ch[p][0]]+size[ch[p][1]]+cnt[p];}
void rotate(int &p,int d)
{
int k=ch[p][d^1];
ch[p][d^1]=ch[k][d];
ch[k][d]=p;
update(p); update(k);
p=k;
}
void ins(int &p,int x)
{
if(!p)
{
p=++tot; val[p]=x;
rnd[p]=rand(); size[p]=cnt[p]=1;
return;
}
if(val[p]==x){ size[p]++; cnt[p]++; return;}
int d=x<val[p]?0:1;
ins(ch[p][d],x);
if(rnd[ch[p][d]]<rnd[p]) rotate(p,d^1);
update(p);
}
void del(int &p,int x)
{
if(!p) return;
if(val[p]==x)
{
if(cnt[p]>1){ size[p]--; cnt[p]--; return;}
else
{
if(!ch[p][1]) p=ch[p][0];
else if(!ch[p][0]) p=ch[p][1];
else
{
int dd=rnd[ch[p][0]]<rnd[ch[p][1]]?1:0;
rotate(p,dd); del(ch[p][dd],x);
}
}
}
else if(x<val[p]) del(ch[p][0],x);
else del(ch[p][1],x);
if(p) update(p);
}
int rank(int p,int x)
{
int ss=size[ch[p][0]];
if(x<val[p]) return rank(ch[p][0],x);
else if(x==val[p]) return ss+1;
else return ss+cnt[p]+rank(ch[p][1],x);
}
int kth(int p,int k)
{
int ss=size[ch[p][0]];
if(k<=ss) return kth(ch[p][0],k);
else if(k<=ss+cnt[p]) return val[p];
else return kth(ch[p][1],k-ss-cnt[p]);
}
int pre(int p,int x)
{
if(!p) return -inf;
if(x<=val[p]) return pre(ch[p][0],x);
else return max(val[p],pre(ch[p][1],x));
}
int nxt(int p,int x)
{
if(!p) return inf;
if(x>=val[p]) return nxt(ch[p][1],x);
else return min(val[p],nxt(ch[p][0],x));
}
int main()
{
n=read();
while(n--)
{
int opt=read(),x=read();
if(opt==1) ins(rt,x);
else if(opt==2) del(rt,x);
else if(opt==3) printf("%d\n",rank(rt,x));
else if(opt==4) printf("%d\n",kth(rt,x));
else if(opt==5) printf("%d\n",pre(rt,x));
else if(opt==6) printf("%d\n",nxt(rt,x));
}
return 0;
}
附赠一个指针版的
#include<iostream>
#include<cstdio>
#include<cmath>
#include<queue>
#include<algorithm>
#include<cstring>
using namespace std;
int read()
{
int f=1,x=0;
char ss=getchar();
while(ss<'0'||ss>'9'){if(ss=='-')f=-1;ss=getchar();}
while(ss>='0'&&ss<='9'){x=x*10+ss-'0';ss=getchar();}
return f*x;
}
const int inf=2e9;
int n;
struct node
{
node* ch[2];
int size,cnt,rnd,val;
node(int v): val(v) { size=cnt=1; rnd=rand(); ch[0]=ch[1]=NULL;}
int cmp(int x){ if(x==val)return -1; return x<val?0:1;}
void update()
{
size=cnt;
if(ch[0]!=NULL) size+=ch[0]->size;
if(ch[1]!=NULL) size+=ch[1]->size;
}
};
node *rt=NULL;
void rotate(node* &p,int d)
{
node* k=p->ch[d^1];
p->ch[d^1]=k->ch[d];
k->ch[d]=p;
p->update(); k->update();
p=k;
}
void ins(node* &p,int x)
{
if(p==NULL){ p=new node(x); return;}
if(x==p->val){ ++p->size; ++p->cnt; return;}
int d=p->cmp(x);
ins(p->ch[d],x);
if(p->ch[d]->rnd < p->rnd) rotate(p,d^1);
p->update();
}
void del(node* &p,int x)
{
if(p==NULL) return;
if(x==p->val)
{
if(p->cnt>1){ --p->size; --p->cnt; return;}
if(p->ch[0]==NULL){ node* k=p; p=p->ch[1]; delete(k);}
else if(p->ch[1]==NULL){ node* k=p; p=p->ch[0]; delete(k);}
else
{
int dd=p->ch[0]->rnd < p->ch[1]->rnd ?1:0;
rotate(p,dd); del(p->ch[dd],x);
}
}
else if(x<p->val) del(p->ch[0],x);
else del(p->ch[1],x);
if(p!=NULL) p->update();
}
int rank(node* p,int x)
{
int ss=p->ch[0]==NULL?0:p->ch[0]->size;
if(x==p->val) return ss+1;
else if(x<p->val) return rank(p->ch[0],x);
else return ss+p->cnt+rank(p->ch[1],x);
}
int kth(node *p,int k)
{
int ss=p->ch[0]==NULL?0:p->ch[0]->size;
if(k<=ss) return kth(p->ch[0],k);
else if(k<=ss+p->cnt) return p->val;
else return kth(p->ch[1],k-ss-p->cnt);
}
int pre(node* p,int x)
{
if(p==NULL) return -inf;
int d=p->cmp(x);
if(d==-1||d==0) return pre(p->ch[0],x);
else return max(p->val,pre(p->ch[1],x));
}
int nxt(node* p,int x)
{
if(p==NULL) return inf;
int d=p->cmp(x);
if(d==-1||d==1) return nxt(p->ch[1],x);
else return min(p->val,nxt(p->ch[0],x));
}
int main()
{
n=read();
while(n--)
{
int k=read(),x=read();
if(k==1) ins(rt,x);
else if(k==2) del(rt,x);
else if(k==3) printf("%d\n",rank(rt,x));
else if(k==4) printf("%d\n",kth(rt,x));
else if(k==5) printf("%d\n",pre(rt,x));
else if(k==6) printf("%d\n",nxt(rt,x));
}
return 0;
}