Treap-平衡树学习笔记

Treap=Tree+Heap

仔细观察Treap这个名字,你会发现他是Tree+Heap的合体
所以他的中文名叫树堆,正如其字面意思
对于权值而言,它是二叉查找树
对于优先级,它是

Treap的每个节点需要维护下列值

int rt;
int ch[maxn][2];//左右孩子指针,0为左孩子,1,为右孩子
int val[maxn],rnd[maxn],size[maxn],cnt[maxn],tot;
//val:权值,cnt:该节点保存相同权值的个数,size:该节点的子树的节点数,rnd:优先级

Treap–旋转

平衡树最重要的操作–旋转
直接按代码的步骤画图模拟一遍就很好理解了

这里以右旋(将p向右旋转)为例,一开始的树长这样
在这里插入图片描述

现在我们用指针K指向他的左子树

在这里插入图片描述

p的左子树变成K的右子树
在这里插入图片描述
K的右子树变成p
在这里插入图片描述

注意这里只在指针K上完成了旋转
我们还要赋值给p,即令p=k

左旋则正好相反,可以尝试自己画一下加深理解

//p:带旋转结点,d==0:左旋, d==1:右旋,旋转见图解 
void rotate(int &p,int d)//记得引用 
{
	int k=ch[p][d^1];
	ch[p][d^1]=ch[k][d];
	ch[k][d]=p;
	update(p); update(k);
	p=k;
}

Treap–插入

既然Treap是二叉查找树
那么插入自然按照二叉查找树的性质
即每个结点的左孩子权值一定小于该节点,而右孩子反之
只要按照二叉查找树的插入方式递归寻找带插入位置就好

然而,Treap还是个堆啊
所以插入后还要判断优先级
之前我们初始化优先级rnd为随机rand()
从而保证了整个程序平均的运行效率

void ins(int &p,int x)//记得引用 
{
	if(!p)//找到带插入节点 
	{ 
		p=++tot; val[p]=x; 
		rnd[p]=rand(); size[p]=cnt[p]=1;
		return;
	}
	if(val[p]==x){ ++size[p]; ++cnt[p]; return;}//若已有权值相同的结点,直接更新 
	int d=x<val[p] ?0:1;//查询应该往哪边插入 ,若x比该节点小,则cmp返回0,反之返回1
	ins(ch[p][d],x);
	if( rnd[ch[p][d]]<rnd[p]) rotate(p,d^1);//维护堆的性质 
	update(p);
}

Treap–删除

寻找待删除结点可以依照二叉查找树的性质
找到待删除节点后有几种情况

1.若结点保存了多个相同的值,则直接更新 s i z e − 1 , c n t − 1 size-1,cnt-1 size1,cnt1即可
2.若结点只有一棵子树,就以该子树代替该节点
3.若结点两棵子树都不为空,则先把优先级较高的子树旋转到根,然后递归在另一子树中继续寻找待删除结点

void del(int &p,int x)//记得引用 
{
	if(!p) return;
	if(val[p]==x)
	{
		if(cnt[p]>1){ --size[p]; --cnt[p]; return;}//若该节点保存了多个相同值,直接更新即可 
		else
		{
			if(!ch[p][0])p=ch[p][1];//如果这个结点只有一棵子树,就以该子树代替该节点
			else if(!ch[p][1])p=ch[p][0];
			else//两棵子树都不为空 
			{
				int dd=rnd[ch[p][0]] < rnd[ch[p][1]] ?1:0;
				rotate(p,dd); del(ch[p][dd],x);
				//先把优先级较高的子树旋转到根,然后递归在另一子树中删除p 
			}
		}
	}
	else if(x<val[p]) del(ch[p][0],x);
	else del(ch[p][1],x);
	if(p) update(p);//删除后一定要更新信息,更新前判断p不为NULL
}

Treap–查询排名

查询数x的排名

int rank(int p,int x)//查询函数都不加引用
{
	int ss=size[ch[p][0]];
	if(x<val[p]) return rank(ch[p][0],x);
	else if(x==val[p]) return ss+1;//找到待查询结点,则其排名为左子树节点数+1
	else return ss+cnt[p]+rank(ch[p][1],x);
	//待查询权值在右子树内
    //则该节点及其左子树必定小于待查询结点
    //所以加上该节点左孩子的size+该点cnt,递归右孩子
}

Treap–查询第K小

查询数x的排名

int kth(int p,int k)
{
	int ss=size[ch[p][0]];
	if(k<=ss) return  kth(ch[p][0],k);//若待查询排名小于左孩子数,则应继续在左子树递归寻找
	else if(k<=ss+cnt[p]) return val[p];//如果排名小于左孩子数加该节点cnt,则该节点为待查询结点
	else return kth(ch[p][1],k-ss-cnt[p]);//如果排名还要大,那就要往右子树找
}

Treap–前驱后继

求x的前驱(前驱定义为小于x,且最大的数)
求x的后继(后继定义为大于x,且最小的数)

int pre(int p,int x)
{
	if(!p) return -inf;
	if(x<=val[p]) return pre(ch[p][0],x);
	else return max(val[p],pre(ch[p][1],x));
}

int nxt(int p,int x)
{
	if(!p) return inf;
	if(x>=val[p]) return nxt(ch[p][1],x);
	else return min(val[p],nxt(ch[p][0],x));
}

一般题目中会出现的Treap操作都包含在了上面
下面给出洛谷 P3369 【模板】普通平衡树的完整代码

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
using namespace std;

int read()
{
	int f=1,x=0;
	char ss=getchar();
	while(ss<'0'||ss>'9'){if(ss=='-')f=-1;ss=getchar();}
	while(ss>='0'&&ss<='9'){x=x*10+ss-'0';ss=getchar();}
	return f*x;
}

const int inf=1e9+7;
const int maxn=500010;
int n,rt,tot;
int rnd[maxn],val[maxn],ch[maxn][2];
int size[maxn],cnt[maxn];

void update(int p){ size[p]=size[ch[p][0]]+size[ch[p][1]]+cnt[p];}

void rotate(int &p,int d)
{
	int k=ch[p][d^1];
	ch[p][d^1]=ch[k][d];
	ch[k][d]=p;
	update(p); update(k);
	p=k;
}

void ins(int &p,int x)
{
	if(!p)
	{
		p=++tot; val[p]=x;
		rnd[p]=rand(); size[p]=cnt[p]=1;
		return;
	}
	if(val[p]==x){ size[p]++; cnt[p]++; return;}
	int d=x<val[p]?0:1;
	ins(ch[p][d],x);
	if(rnd[ch[p][d]]<rnd[p]) rotate(p,d^1);
	update(p);
}

void del(int &p,int x)
{
	if(!p) return;
	if(val[p]==x)
	{
		if(cnt[p]>1){ size[p]--; cnt[p]--; return;}
		else
		{
			if(!ch[p][1]) p=ch[p][0];
			else if(!ch[p][0]) p=ch[p][1];
			else
			{
				int dd=rnd[ch[p][0]]<rnd[ch[p][1]]?1:0;
				rotate(p,dd); del(ch[p][dd],x);
			} 
		}
	}
	else if(x<val[p]) del(ch[p][0],x);
	else del(ch[p][1],x);
	if(p) update(p);
}

int rank(int p,int x)
{
	int ss=size[ch[p][0]];
	if(x<val[p]) return rank(ch[p][0],x);
	else if(x==val[p]) return ss+1;
	else return ss+cnt[p]+rank(ch[p][1],x);
}

int kth(int p,int k)
{
	int ss=size[ch[p][0]];
	if(k<=ss) return kth(ch[p][0],k);
	else if(k<=ss+cnt[p]) return val[p];
	else return kth(ch[p][1],k-ss-cnt[p]);
}

int pre(int p,int x)
{
	if(!p) return -inf;
	if(x<=val[p]) return pre(ch[p][0],x);
	else return max(val[p],pre(ch[p][1],x));
}

int nxt(int p,int x)
{
	if(!p) return inf;
	if(x>=val[p]) return nxt(ch[p][1],x);
	else return min(val[p],nxt(ch[p][0],x));
}

int main()
{
	n=read();
	while(n--)
	{
		int opt=read(),x=read();
		if(opt==1) ins(rt,x);
		else if(opt==2) del(rt,x);
		else if(opt==3) printf("%d\n",rank(rt,x));
		else if(opt==4) printf("%d\n",kth(rt,x));
		else if(opt==5) printf("%d\n",pre(rt,x));
		else if(opt==6) printf("%d\n",nxt(rt,x));
	}
	return 0;
}

附赠一个指针版的

#include<iostream>
#include<cstdio>
#include<cmath>
#include<queue>
#include<algorithm>
#include<cstring>
using namespace std;

int read()
{
    int f=1,x=0;
    char ss=getchar();
    while(ss<'0'||ss>'9'){if(ss=='-')f=-1;ss=getchar();}
    while(ss>='0'&&ss<='9'){x=x*10+ss-'0';ss=getchar();}
    return f*x;
}

const int inf=2e9;
int n; 
struct node
{
    node* ch[2];
    int size,cnt,rnd,val;
    node(int v): val(v) { size=cnt=1; rnd=rand(); ch[0]=ch[1]=NULL;}
    int cmp(int x){ if(x==val)return -1; return x<val?0:1;}
    void update()
    {
        size=cnt;
        if(ch[0]!=NULL) size+=ch[0]->size;
        if(ch[1]!=NULL) size+=ch[1]->size;
    }
};
node *rt=NULL;

void rotate(node* &p,int d)
{
    node* k=p->ch[d^1];
    p->ch[d^1]=k->ch[d];
    k->ch[d]=p;
    p->update(); k->update();
    p=k;
}

void ins(node* &p,int x)
{
    if(p==NULL){ p=new node(x); return;}
    if(x==p->val){ ++p->size; ++p->cnt; return;}
    int d=p->cmp(x);
    ins(p->ch[d],x);
    if(p->ch[d]->rnd < p->rnd) rotate(p,d^1);
    p->update();
}

void del(node* &p,int x)
{
    if(p==NULL) return;
    if(x==p->val)
    {
        if(p->cnt>1){ --p->size; --p->cnt; return;}
        if(p->ch[0]==NULL){ node* k=p; p=p->ch[1]; delete(k);}
        else if(p->ch[1]==NULL){ node* k=p; p=p->ch[0]; delete(k);}
        else
        {
            int dd=p->ch[0]->rnd < p->ch[1]->rnd ?1:0;
            rotate(p,dd); del(p->ch[dd],x);
        }
    }
    else if(x<p->val) del(p->ch[0],x);
    else del(p->ch[1],x);
    if(p!=NULL) p->update();
}

int rank(node* p,int x)
{
    int ss=p->ch[0]==NULL?0:p->ch[0]->size;
    if(x==p->val) return ss+1;
    else if(x<p->val) return rank(p->ch[0],x);
    else return ss+p->cnt+rank(p->ch[1],x);
}

int kth(node *p,int k)
{
    int ss=p->ch[0]==NULL?0:p->ch[0]->size;
    if(k<=ss) return kth(p->ch[0],k);
    else if(k<=ss+p->cnt) return p->val;
    else return kth(p->ch[1],k-ss-p->cnt);
}

int pre(node* p,int x)
{
    if(p==NULL) return -inf;
    int d=p->cmp(x);
    if(d==-1||d==0) return pre(p->ch[0],x);
    else return max(p->val,pre(p->ch[1],x));
}

int nxt(node* p,int x)
{
    if(p==NULL) return inf;
    int d=p->cmp(x);
    if(d==-1||d==1) return nxt(p->ch[1],x);
    else return min(p->val,nxt(p->ch[0],x));
}

int main()
{
    n=read();
    while(n--)
    {
        int k=read(),x=read();
        if(k==1) ins(rt,x);
        else if(k==2) del(rt,x);
        else if(k==3) printf("%d\n",rank(rt,x));
        else if(k==4) printf("%d\n",kth(rt,x));
        else if(k==5) printf("%d\n",pre(rt,x));
        else if(k==6) printf("%d\n",nxt(rt,x));
    }
    return 0;
}
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值