平衡树——treap

treap实际上就是tree(BST,二叉搜索树)+heap(堆)

我们维护一个二叉树来储存值,但是为了避免二叉树由于值太特殊变成链式结构,我们对于每个点加入一个val值,这个是随机值,我们通过这个随机值来维护一个大根堆(只与val有关的大根堆),进而使得我们维护能够用一个比较对称的二叉树来维护所有数据,可以类比AVL树来思考。

那么我们为什么要用treap,或者换句话说treap可以进行哪些操作呢?

1.插入

2.删除

3.找前驱和后继(小于某个数最大的数和大于某个数最小的数,BST中不同节点数不同)

4.找最大/最小

5.求某个值的排名

6.求排名是rank的数是哪个

7.比某个数小的最大值(数可能不在我们维护的区间中)

8.比某个数大的最小值(同上)

对于这些操作,显然前几个是可以用set实现的,但是求某个值的排名和求排名是k的数是哪个,这两个操作显然用set没办法很好的实现。所以我们必须要使用Treap。能用set的一定可以用Treap,能用Treap的不一定能用set。

那么我们就要来看如何实现Treap,Treap是基于BST的,那么我们先来回顾一下BST,BST就是维护一棵二叉树,对于这个二叉树而言,它左子节点的值一定小于它,右子节点的值一定大于它。

我们一般是以第一个数为根节点的,但是如果要维护的区间是单调的,那么二叉树就会退化成链式结构,很影响时间复杂度。所以为了避免二叉树退化成链式结构,我们在每个点上增加一个随机值变量,然后根据这个随机值维护一个大根堆,也即父节点的值一定大于左右子节点。然后如果不满足的话,那么就把左右子节点中的大的点通过旋转操作放到根节点上去,旋转操作如下:

可以发现即使我们进行了旋转操作,维护的仍然是一个二叉树,不会对我们在二叉树上的操作产生影响。与此同时,这也给了我们一个删除中间节点的思路,我们可以把中间节点换到叶子节点上进行删除。

那么我们下面来看代码层面的实现:

首先是对每个节点的定义,最基本的需要一下几个值(当然由于题目不同,可能还要加一些别的变量):

struct node{
    int l,r;//l,r表示左右子节点的下标,我们这里用下标表示指针
    int key,val;//key表示我们实际存的值,val存的是我们随机赋的值,用以维护堆
}

 然后我们来看有哪些基本操作:

新增节点:

int get_node()
{
    tr[++idx].key=key;
    tr[idx].val=rand();
    return idx;
}

左旋:

右到父,父到左,右左到左右。先拔再转。

void zag(int &p)//左旋
{
    int q=tr[p].r;
    tr[p].r=tr[q].l,tr[q].l=p,p=q;
    pushup(tr[p].l),pushup(p);
}

 通过前三个赋值,我们已经实现了子节点的交换,相对于原来的p的父节点的更新我们该怎么实现了,显然我们没有存父节点的信息, 不能从子节点访问到父节点,但是我们肯定是在递归中实现的,我们只要在递归的时候传入tr[u].l的引用,那么在这一层进行修改的时候就可以实现对tr[u].l的修改。

右旋:

void zig(int &p)//右旋
{
    int q=tr[p].l;
    tr[p].l=tr[q].r,tr[q].r=p,p=q;
    pushup(tr[p].r),pushup(p);
}

这里的pushup就是用子节点更新父节点的一些信息,和线段树中的pushup差不多。

初始化:

初始化的时候,我们设置了两个哨兵,它们的值分别赋值为正无穷和负无穷,先令root为负无穷对应的点1,它的右子为正无穷对应的点2.然后根据它们的val,判断是否需要左旋。

void build()
{
    get_node(-INF),get_node(INF);
    root=1,tr[root].r=2;
    pushup(root);
    if(tr[1].val<tr[2].val) zag(root);//维护大根堆
}

插入

和二叉树的插入操作一样,注意判断是否需要旋转操作即可。

void insert(int &p,int key)
{
    if(!p) get_node(key);
    else if(tr[p].key<key)
    {   
        insert(tr[p].r,key);
        if(tr[p].val<tr[tr[p].r].val) zig(p);
    }
    else 
    {   
        insert(tr[p].l,key);
        if(tr[p].val<tr[tr[p].l].val) zag(p);
    }
}

删除

删除要分三种情况,一种是要删除的数不存在,那么最后搜到的位置就是0,那么直接返回即可;一种是要删除的数在叶子节点上,那么我们直接把这个叶子节点的下标修改成0即可,因为我们在上一层传入的是它父节点的tr[u].l或者tr[u].r的引用,所以这里直接赋成0就相当于将它父节点的左指针或者右指针的指向修改成空;最后一种最麻烦,就是要删除的数不在叶子节点上,那么就需要通过旋转将它换到叶子节点上,然后再用第二种方法执行删除。

void remove(int &p,int key)
{
    if(!p) return;
    if(tr[p].l||tr[p.r])
    {
        if(!tr[p].r||tr[tr[p].l].val>tr[tr[p].r].val)
        {
            zig(p);//p和原p的父节点的某个子节点都被修改成了原p的左子节点
            remove(tr[p].r,key);
        }
        else
        {
            zag(p);
            remove(tr[p].l,key);
        }
    }
    else p=0;
}

找严格小于key的数中最大的数

int get_prev(int p,int key)
{
    if(!p) return -INF;//没有小于key的数
    if(tr[p].key>=key) return get_prev(tr[p].l,key);
    return max(tr[p].key,get_prev(tr[p].r,key));
}

找到严格大于key的最小数

int get_next(int p,int key)
{
    if(!p) return INF;
    if(p.key<=key) return get_next(tr[p].r,key);
    return min(tr[p].key, get_next(tr[p].l,key) );
}

找最大:

int get_max(int p)
{
    if(!p) return -INF;
    return max(tr[p].key,get_max(tr[p].r));
}

找最小:

int get_min(int p)
{
    if(!p) return INF;
    return min(tr[p].key,get_max(tr[p].l));
}

求某个值的排名

求排名的时候为了方便,我们引入了以某个点为根节点的子树的大小size,如果有重复元素,我们还可以引入一个cnt变量,表示值为当前节点的点有多少个。

int get_rank_by_key(int p,int key)
{
    if(!p) return 0;
    if(tr[p].key==key) return tr[tr[p].l].size+1; 
    if(tr[p].key>key) return get_rank_by_key(tr[p].l,key);
    return  tr[tr[p].l].size()+get_rank_by_key(tr[p].r,key);
}

求排名是rank的数是哪个

int get_key_by_rank(int p,int rank)
{
    if(!p) return INF;
    if(tr[tr[p].l].size >= rank) return get_key_by_rank(tr[p].l,rank);
    if(tr[p].size+tr[p].cnt>=key) return tr[p].key;//cnt表示的是当前数有多少个
    return  get_key_by_rank(tr[p].r,rank-tr[tr[p].l].size-tr[p].cnt);
}

差不多就是这些操作,剩下的我们根据具体的题目再来分析。

253. 普通平衡树(253. 普通平衡树 - AcWing题库

这个就是一道比较裸的题目,这里要注意我们需要在原来节点定义的基础上引入size和cnt,因为这里有根据排名求数和根据数求排名两个操作。

#include<bits/stdc++.h>
using namespace std;
const int inf=0x3f3f3f3f;
struct node{
    int l,r;
    int key,val;
    int size,cnt;
}tr[100010];
int root,idx;
int get_node(int key)
{
   tr[++idx].key=key;
   tr[idx].val=rand();
   tr[idx].size=tr[idx].cnt=1;
   return idx;
}
void pushup(int p)
{
    tr[p].size=tr[tr[p].l].size+tr[p].cnt+tr[tr[p].r].size;
}
void left(int &p)//左旋
{
    int q=tr[p].r;
    tr[p].r=tr[q].l,tr[q].l=p,p=q;
    pushup(tr[p].l),pushup(p);
}
void right(int &p)//右旋
{
    int q=tr[p].l;
    tr[p].l=tr[q].r,tr[q].r=p,p=q;
    pushup(tr[p].r),pushup(p);
}
void build()
{
    get_node(-inf),get_node(inf);
    root=1,tr[1].r=2;
    pushup(root);
    if(tr[1].val<tr[2].val) left(root);
}
void insert(int &p,int key)
{
    if(!p) p=get_node(key);//这里赋值了才算被插进去,因为p传入的是引用
    else if(tr[p].key==key) tr[p].cnt++;
    else if(tr[p].key>key) 
    {
        insert(tr[p].l,key);
        if(tr[p].val<tr[tr[p].l].val) right(p);
    }
    else 
    {
        insert(tr[p].r,key);
        if(tr[p].val<tr[tr[p].r].val) left(p);
    }
    pushup(p);
}
void remove(int &p,int key)
{
    if(!p) return;
    if(tr[p].key==key)
    {
        if(tr[p].cnt>1) tr[p].cnt--;
        else if(tr[p].l||tr[p].r)
        {
                if(!tr[p].r||tr[tr[p].l].val>tr[tr[p].r].val)
                {
                    right(p);
                    remove(tr[p].r,key);
                }
                else 
                {
                    left(p);
                    remove(tr[p].l,key);
                }
            
            
        }else p=0;
    }
    else if(tr[p].key>key) remove(tr[p].l,key);
    else remove(tr[p].r,key);
    pushup(p);
}
int get_rank_by_key(int p,int key)
{
    if(!p) return 0;
    if(tr[p].key==key) return tr[tr[p].l].size+1;
    else if(tr[p].key>key) return get_rank_by_key(tr[p].l,key);
    else return tr[tr[p].l].size+tr[p].cnt+get_rank_by_key(tr[p].r,key);
}
int get_key_by_rank(int p,int rank)
{
    if(!p) return inf;
    if(tr[tr[p].l].size>=rank) return get_key_by_rank(tr[p].l,rank);
    else if(tr[tr[p].l].size+tr[p].cnt>=rank) return tr[p].key;
    else get_key_by_rank(tr[p].r,rank-tr[p].cnt-tr[tr[p].l].size);
}
int get_prev(int p,int key)
{
    if(!p) return -inf;
    if(tr[p].key>=key) return get_prev(tr[p].l,key);
    else return max(tr[p].key,get_prev(tr[p].r,key));
}
int get_next(int p,int key)
{
    if(!p) return inf;
    if(tr[p].key<=key) return get_next(tr[p].r,key);
    else return min(tr[p].key,get_next(tr[p].l,key));
}
int main()
{
    build();
    int n;
    scanf("%d",&n);
    while(n--)
    {
        int op,x;
        scanf("%d%d",&op,&x);
        if(op==1) insert(root,x);
        else if(op==2) remove(root,x);
        else if(op==3) cout<<get_rank_by_key(root,x)-1<<endl;//因为有哨兵
        else if(op==4) cout<<get_key_by_rank(root,x+1)<<endl;
        else if(op==5) cout<<get_prev(root,x)<<endl;
        else cout<<get_next(root,x)<<endl;
    }
}

265. 营业额统计(265. 营业额统计 - AcWing题库

思路:这题对于每个数要找到在它插入前距离它最近的数,那么实际上就是在每个数插入前找它的前驱(最大的最小值)和后继(最小的最大值),然后取离它更近的那个值。

#include<bits/stdc++.h>
using namespace std;
const int inf=0x3f3f3f3f;
struct node{
    int l,r;
    int key,val;
}tr[100010];
int root,idx;
int get_node(int key)
{
    tr[++idx].key=key;
    tr[idx].val=rand();
    return idx;
}
void left(int &p)
{
    int q=tr[p].r;
    tr[p].r=tr[q].l,tr[q].l=p,p=q;
}
void right(int &p)
{
    int q=tr[p].l;
    tr[p].l=tr[q].r,tr[q].r=p,p=q;
}
void build()
{
    get_node(-inf),get_node(inf);
    root=1,tr[1].r=2;
    if(tr[1].val<tr[2].val) left(root);
}
void insert(int &p,int key)
{
    if(!p) p=get_node(key);
    else if(tr[p].key==key) return;
    else if(tr[p].key>key) 
    {
        insert(tr[p].l,key);
        if(tr[tr[p].l].val>tr[p].val) right(p);
    }
    else 
    {
        insert(tr[p].r,key);
        if(tr[tr[p].r].val>tr[p].val) left(p);
    }
}
int get_prev(int p,int key)
{
    if(!p) return -inf;
    if(tr[p].key>key) return get_prev(tr[p].l,key);
    return max(tr[p].key,get_prev(tr[p].r,key));
}
int get_next(int p,int key)
{
    if(!p) return inf;
    if(tr[p].key<key) return get_next(tr[p].r,key);
    return min(tr[p].key,get_next(tr[p].l,key));
}
int main()
{
    build();
    int n;
    scanf("%d",&n);
    long long ans=0;
    for(int i=1;i<=n;i++)
    {
        int x;
        scanf("%d",&x);
        if(i==1) ans+=x; 
        else ans+=min(x-get_prev(root,x),get_next(root,x)-x);
        insert(root,x);
    }
    printf("%lld",ans);
}

在插入的时候,一定不要忘记判断是否需要左右旋。

  • 20
    点赞
  • 30
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值