不GG的Splay

Splay(伸展树)是一种维护二叉搜索树的数据结构,可以用它干一些很神奇的东西,这篇文章先来介绍它的基本操作。

首先定义几个变量:

  • fa[x] 表示 x 的父节点

  • ch[x][y] 表示 x 的儿子节点,y=0 表示左儿子,y=1 表示右儿子

  • cnt[x] 表示 x 这个数出现了几次

  • val[x] 表示 x 节点的权值是多少

  • size[x] 表示以 x 为根的树节点个数(树的大小)

  • tot_size 表示树的总大小

  • root 表示当前根节点是哪个


下面介绍操作:

clear(x)

把 x 节点上的所有信息清空

void clear(int x) {
    fa[x]=ch[x][0]=ch[x][1]=cnt[x]=size[x]=val[x]=0;
}

get(x)

判断 x 节点为它父节点的左儿子还是右儿子(左0右1)

int get(int x) {
    return ch[fa[x]][1]==x;
}

update(x)

维护以 x 为根的树的大小

在下面的操作的时候如果会update很多点,一定要从下往上维护。

void update(int x) {
    if(x) {
        size[x]=cnt[x];
        if(ch[x][0]) size[x]+=size[ch[x][0]];
        if(ch[x][1]) size[x]+=size[ch[x][1]];
    }
}

rotate(x)

Splay中最最最重要的一个环节。

把 x 节点旋转到 x 的父节点的位置。

可是这是二叉树呀,这样操作不就乱了吗?

所以我们要维护某些节点之间的父子关系。

首先我们要明确在这次操作中会涉及到的节点:

1、x,你就是转它肯定会涉及它呀

2、fa[x],你要把 x 转到那里肯定也会涉及到它

3、fa[fa[x]],把 x 转到了 fa[x] 时,fa[fa[x]] 的儿子就不是 fa[x] 了,会变成 x

好了,rotate操作就会涉及到这 3 个节点,每个节点改变它的父亲和儿子,就会有六条语句,其中如果 fa[x] 已经是根了,那么就不用改变 fa[fa[x]] 的儿子了。

最关键的问题来了:父子关系怎么分配呢?

因为我们想把 x 到 fa[x] 的位置,那么它们的的父子关系必然会互换。唯一要确定的就是左右儿子的问题。

假设 x 是 fa[x] 的左儿子,那么 fa[x] 的原本的左儿子 x 将会变成 x 的右儿子(这里为什么是右儿子,因为这样才会保持二叉搜索树的性质)。反之亦然,所以我们用 which 来记录 x 与 fa[x] 的关系,最后维护一下旋转后的树的大小(因为 fa[x] 已经是 x 的儿子了,所以先update(fa[x])),代码如下:

void rotate(int x) {
    int pa=fa[x],papa=fa[pa],which=get(x);
    ch[pa][which]=ch[x][!which];fa[ch[x][!which]]=pa;
    ch[x][!which]=pa;fa[pa]=x;
    fa[x]=papa;if(papa) ch[papa][ch[papa][1]==pa]=x;
    update(pa);update(x);
}

splay(x)

这个函数是通过不断的rotate把 x 转到根的位置。

注意三点一线的时候是先转fa[x]再转x

void splay(int x) {
    for(int f;f=fa[x];rotate(x)) {
        if(fa[f]) rotate(get(f)==get(x)?f:x);
    }
    root=x;
}

insert(x)

插入一个数x。

三种情况:

1、**空树。**直接改改信息return就好了。

2、**x重复。**cnt[x]++,维护一下return。

3、**找到了最底下。**新开节点维护一下return。

这个具体看代码吧,应该很好理解。

下面两种情况不要忘记splay一下。

void insert(int v) {
    if(root==0) {
        tot_size++;
        ch[tot_size][0]=ch[tot_size][1]=fa[tot_size]=0;
        val[tot_size]=v;
        cnt[tot_size]=size[tot_size]=1;
        root=tot_size;
        return;
    }
    int f=0,now=root;
    while(true) {
        if(val[now]==v) {
            cnt[now]++;
            update(now);
            update(f);
            splay(now);
            return;
        }
        f=now;
        now=ch[now][val[now]<v];
        if(now==0) {
            tot_size++;
            ch[tot_size][0]=ch[tot_size][1]=0;
            fa[tot_size]=f;
            val[tot_size]=v;
            cnt[tot_size]=1,size[tot_size]=1;
            ch[f][val[f]<v]=tot_size;
            update(f);
            splay(tot_size);
            return;
        }
    }
}

find(x)

查找x这个数的排名

就按照二叉搜索树的性质往下查找,注意我们在往左子树找的时候是不用累加结果的,因为最左边的就是第一个,在往右边找的时候再加上左子树的大小,找到的时候别忘了把 x splay到根方便以后的操作。

int find(int x) {
    int res=0,now=root;
    while(true) {    
        if(x<val[now]) {
            now=ch[now][0];
        }
        else {
            res+=size[ch[now][0]];
            if(x==val[now]) {
                splay(now);
                return res+1;
            }
            res+=cnt[now];
            now=ch[now][1];
        }
    }
}

findx(x)

查找排名为x的树的节点

和find类似无非就是多判断一下子树的大小看看能否继续查找,temp表示的是已经搜了多少个节点。

int findx(int p) {
    int now=root;
    while(true) {
        if(ch[now][0] && p<=size[ch[now][0]]) {
            now=ch[now][0];
        }
        else {
            int temp=size[ch[now][0]]+cnt[now];
            if(p<=temp) return val[now];
            p-=temp;
            now=ch[now][1];
        }
    }
}

pre() 和 next()

查找根节点的前驱和后继节点

如果要查找x的前驱或后继的话,就先insert(x),把它转到根,再del(x),删除。

这个操作很简单,根节点的前驱就是根节点左子树中最靠右的那个,后继就是右子树中最靠左的那个,想一想,为什么?

int pre() {
    int now=ch[root][0];
    while(ch[now][1]) now=ch[now][1];
    return now;
}

int next() {
    int now=ch[root][1];
    while(ch[now][0]) now=ch[now][0];
    return now;
}

del(x)

删除大小为x的节点

首先我们随便find一下,目的是让x转到根节点,现在root就是x。

然后就会出现下面几种情况:

1、**x有重复。**那么直接cnt[root]–,return就好了。

2、**root没有儿子了,即树上只有x一个节点。**那么直接删除根节点,return。

3、**root只有左儿子或只有右儿子。**那就把它的这个儿子变成父亲,然后删除父亲,return。

4、**root有两个儿子。**那么为了满足二叉搜索树的性质,我们把根的前驱变成新的根,再把原来根的右子树接到新根的右儿子上,最后删除原来的根,维护一下新根,return。

void del(int x) {
    int gg=find(x);
    if(cnt[root]>1) {
        cnt[root]--;
        return;
    }
    if(!ch[root][0] && !ch[root][1]) {
        clear(root);
        root=0;
        return;
    }
    if(!ch[root][0]) {
        int oldroot=root;
        root=ch[root][1];
        fa[root]=0;
        clear(oldroot);
        return;
    }
    else if(!ch[root][1]) {
        int oldroot=root;
        root=ch[root][0];
        fa[root]=0;
        clear(oldroot);
        return;
    }
    int oldroot=root;
    splay(pre());
    fa[ch[oldroot][1]]=root;
    ch[root][1]=ch[oldroot][1];
    clear(oldroot);
    update(root);
    return;
}

最后整合成一道模板题

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define MAXN 1000005

using namespace std;

int read() {
    int x=0,f=1;char ch=getchar();
    while(ch<'0' || ch>'9') {if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0' && ch<='9') {x=x*10+ch-'0';ch=getchar();}
    return x*f;
}

int fa[MAXN],cnt[MAXN],ch[MAXN][2],size[MAXN],val[MAXN],tot_size,root;

void clear(int x) {
    fa[x]=ch[x][0]=ch[x][1]=cnt[x]=size[x]=val[x]=0;
}

int get(int x) {
    return ch[fa[x]][1]==x;
}

void update(int x) {
    if(x) {
        size[x]=cnt[x];
        if(ch[x][0]) size[x]+=size[ch[x][0]];
        if(ch[x][1]) size[x]+=size[ch[x][1]];
    }
}

void rotate(int x) {
    int pa=fa[x],papa=fa[pa],which=get(x);
    ch[pa][which]=ch[x][!which];fa[ch[x][!which]]=pa;
    ch[x][!which]=pa;fa[pa]=x;
    fa[x]=papa;if(papa) ch[papa][ch[papa][1]==pa]=x;
    update(pa);update(x);
}

void splay(int x) {
    for(int f;f=fa[x];rotate(x)) {
        if(fa[f]) rotate(get(f)==get(x)?f:x);
    }
    root=x;
}

void insert(int v) {
    if(root==0) {
        tot_size++;
        ch[tot_size][0]=ch[tot_size][1]=fa[tot_size]=0;
        val[tot_size]=v;
        cnt[tot_size]=size[tot_size]=1;
        root=tot_size;
        return;
    }
    int f=0,now=root;
    while(true) {
        if(val[now]==v) {
            cnt[now]++;
            update(now);
            update(f);
            splay(now);
            return;
        }
        f=now;
        now=ch[now][val[now]<v];
        if(now==0) {
            tot_size++;
            ch[tot_size][0]=ch[tot_size][1]=0;
            fa[tot_size]=f;
            val[tot_size]=v;
            cnt[tot_size]=1,size[tot_size]=1;
            ch[f][val[f]<v]=tot_size;
            update(f);
            splay(tot_size);
            return;
        }
    }
}

int find(int x) {
    int res=0,now=root;
    while(true) {    
        if(x<val[now]) {
            now=ch[now][0];
        }
        else {
            res+=size[ch[now][0]];
            if(x==val[now]) {
                splay(now);
                return res+1;
            }
            res+=cnt[now];
            now=ch[now][1];
        }
    }
}

int findx(int p) {
    int now=root;
    while(true) {
        if(ch[now][0] && p<=size[ch[now][0]]) {
            now=ch[now][0];
        }
        else {
            int temp=size[ch[now][0]]+cnt[now];
            if(p<=temp) return val[now];
            p-=temp;
            now=ch[now][1];
        }
    }
}

int pre() {
    int now=ch[root][0];
    while(ch[now][1]) now=ch[now][1];
    return now;
}

int next() {
    int now=ch[root][1];
    while(ch[now][0]) now=ch[now][0];
    return now;
}

void del(int x) {
    int gg=find(x);
    if(cnt[root]>1) {
        cnt[root]--;
        return;
    }
    if(!ch[root][0] && !ch[root][1]) {
        clear(root);
        root=0;
        return;
    }
    if(!ch[root][0]) {
        int oldroot=root;
        root=ch[root][1];
        fa[root]=0;
        clear(oldroot);
        return;
    }
    else if(!ch[root][1]) {
        int oldroot=root;
        root=ch[root][0];
        fa[root]=0;
        clear(oldroot);
        return;
    }
    int oldroot=root;
    splay(pre());
    fa[ch[oldroot][1]]=root;
    ch[root][1]=ch[oldroot][1];
    clear(oldroot);
    update(root);
    return;
}

int main() {
    int n=read();
    while(n--) {
        int opt=read(),x=read();
        if(opt==1) insert(x);
        if(opt==2) del(x);
        if(opt==3) printf("%d\n",find(x));
        if(opt==4) printf("%d\n",findx(x));
        if(opt==5) {
            insert(x);
            printf("%d\n",val[pre()]);
            del(x);
        }
        if(opt==6) {
            insert(x);
            printf("%d\n",val[next()]);
            del(x);
        }
    }
    return 0;
}
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值