初识splay tree (三)

以 [NOI 2005 维护序列][1]为例,说明一下具体的splay tree 的代码编写和实现技巧。
这里写图片描述
这里写图片描述
这里写图片描述
这里写图片描述
这里写图片描述

翻转真的是一个BT的操作~~
一步一步来,首先,结构体定义:

typedef struct node{
        struct node *pre,*ch[2];
        int size,value,sums,maxs,maxl,maxr;
        bool rev,cover;
}node;
node *root,*nill,buf[MAXN],*stk[MAXN];
int bufsize,stksize;

root指向树根节点 nill指向所有的空叶子节点
buf和stk则是用来进行节点的预分配和回收,原题目只给了64MB内存,不回收会MLE。
sums代表子区间的和,maxs代表该子区间的最大子段和,
maxl代表该子区间左起最大的累加和
maxr代表该子区间右起最大的累加和

新建空树,初始化过程如下:

node *getnode(int value,node *fa){
        node *p=NULL;
        if(bufsize<MAXN){
            p=&buf[bufsize++];
        }else if(stksize>0){
            p=stk[--stksize];
            if(p->ch[0]!=nill) stk[stksize++]=p->ch[0];
            if(p->ch[1]!=nill) stk[stksize++]=p->ch[1];
        }else {
            p=(node*)malloc(sizeof(node));
        } p->size=1;
        p->value=p->sums=p->maxs=p->maxl=p->maxr=value;
        p->rev=p->cover=false; p->pre=fa; p->ch[0]=p->ch[1]=nill;
        return p;
}
int init(){
        bufsize=stksize=0;
        nill=getnode(-INF,NULL);
        nill->ch[0]=nill->ch[1]=nill->pre=NULL;
        nill->size=nill->sums=0;
        root=getnode(-INF,nill);
        root->ch[1]=getnode(-INF,root);
        update(root);
        return 0;
}

这里实现涉及了两个小技巧:
增加nill节点,不对nill进行update操作,将其sums和size置0,可以省去边界叶子节点判断的麻烦。
同时,为整个splay树增加了两个边界节点,value设为负无穷,方便区间边界的处理。

查找第POS位置的节点,并提取到根:

int select(int pos,node *fa){
        node *rt=root;
        while(rt!=nill){
            pushdown(rt);
            if(rt->ch[0]->size+1 == pos) break;
            else if(rt->ch[0]->size >= pos) rt=rt->ch[0];
            else {pos-=rt->ch[0]->size+1; rt=rt->ch[1];}
        }
        splay(rt,fa);
        return 0;
    }

这里,size域起到了作用,代表了该区间所有节点的总数,
注意到附加的全局叶子结点nill->size域为0,因此,nill不会影响到size域的计数。
而两个边界节点的size初始为1,是参与计数的。
查找方法很简单,判断当前节点左边孩子数量是否等于pos-1,若是此时的当前节点就是我们要找的,否则若小于pos-1,则往左走,再否则往右走。
同时每次访问前都要pushdown操作时必须的。

Splay(rt,fa),将rt伸展到指定的 祖先 fa下面:

int pushdown(node *x){
        if(x==nill || x==NULL) return 0;
        if(x->rev){
            x->rev=false;
            x->ch[0]->rev = !x->ch[0]->rev;
            x->ch[1]->rev = !x->ch[1]->rev;
            node *tmp=x->ch[0]; x->ch[0]=x->ch[1]; x->ch[1]=tmp;
            int t=x->maxl; x->maxl=x->maxr; x->maxr=t;
        }
        if(x->cover){
            x->cover=false;
            x->ch[0]->cover= x->ch[1]->cover=true;
            x->ch[0]->value= x->ch[1]->value=x->value;
            x->maxl=x->maxr=x->maxs=x->sums=x->value*x->size;
            if(x->value<0) x->maxl=x->maxr=x->maxs=x->value;
        }
        return 0;
}
    int update(node *x){
        if(x==nill || x==NULL) return 0;
        pushdown(x->ch[0]); pushdown(x->ch[1]); //must pushdown when you visited it
        x->size = x->ch[0]->size + x->ch[1]->size + 1;
        x->sums = x->ch[0]->sums + x->ch[1]->sums + x->value;
        x->maxl = max(x->ch[0]->maxl,x->ch[0]->sums+x->value);
        x->maxl = max(x->maxl,x->ch[0]->sums+x->value+x->ch[1]->maxl);
        x->maxl = max(x->maxl,x->ch[0]->sums+x->value+x->ch[1]->sums);
        x->maxr = max(x->ch[1]->maxr,x->ch[1]->sums+x->value);
        x->maxr = max(x->maxr,x->ch[1]->sums+x->value+x->ch[0]->maxr);
        x->maxr = max(x->maxr,x->ch[1]->sums+x->value+x->ch[0]->sums);
        x->maxs = max(x->value, x->ch[0]->maxs);
        x->maxs = max(x->maxs, x->ch[1]->maxs);
        x->maxs = max(x->maxs, x->ch[0]->maxr+x->value);
        x->maxs = max(x->maxs, x->ch[1]->maxl+x->value);
        x->maxs = max(x->maxs, x->ch[0]->maxr + x->value + x->ch[1]->maxl);
        return 0;
}
int rotateto(node *x,int to){
        node *xp=x->pre;
        //
        //pushdown(xp); pushdown(x);
        xp->ch[to^1] = x->ch[to];
        xp->ch[to^1]->pre = xp;
        x->pre = xp->pre;
        if(xp->pre->ch[0] == xp) xp->pre->ch[0]=x;
        else xp->pre->ch[1]=x;
        xp->pre=x; x->ch[to]=xp;
        update(xp);//update(x) is not needed
        if(xp == root) root=x;
        return 0;
    }
    int splay(node *x,node *fa){
        pushdown(x);//this is must needed,note : cover && reverse
        while(x->pre!=fa){
            if(x->pre->pre == fa){
                if(x->pre->ch[0] == x) rotateto(x,1);
                else rotateto(x,0);
            }else{
                if(x->pre->pre->ch[0]==x->pre){
                    if(x->pre->ch[0] == x) {rotateto(x->pre,1);rotateto(x,1);}
                    else {rotateto(x,0);rotateto(x,1);}
                }else{
                    if(x->pre->ch[1] == x) {rotateto(x->pre,0);rotateto(x,0);}
                    else {rotateto(x,1);rotateto(x,0);}
                }
            }
        }
        update(x);
        return 0;
}

其中pushdown和update操作的功能是明确区分开的:

  1. pushdown负责将当前节点的标记落实,并传递给孩子节点,最后清除自己的标记信息。
  2. update负责根据左右孩子节点(一定是从左右孩子处获取最新信息,因为旋转过程会使得左右孩子发生变化)更新自己。

这里就是splay的核心操作,旋转操作将x节点不断旋转到祖先fa下面,
注意到:

  1. 每次旋转都会将X提高一层,x的父节点xp随之下降一层
  2. 每次提升x,必须update(xp),因为其左右孩子发生了变化。
  3. update(x)可以在所有rotateto之后执行
  4. splay方法结束时,无需update(fa)
  5. 每次select操作都会从root到x的路径上执行pushdown,所以select内部的splay、rotateto操作中pushdown可以节省掉
  6. 同时splay则对应了一条upadate路径,select的最后调用了splay(),这意味着select结束时,从root出发的某一条路径是没有lazy
    tag的。
  7. 这意味着,当我们select之后,再次访问(只读取,不修改)这条路径上的某个节点时,无需再次进行pushdown。
  8. 但是,对于7,如果我们不仅访问,并且修改了这条路径上的某个节点,此时splay该节点,则必须进行pushdown。

其余的操作,插入,更新,删除,等:

int insert(int pos,int tot){
        node *p,*q;
        int s;
        scanf("%d",&s);
        p=q=getnode(s,nill);
        for(int i=1;i<tot;i++){
            scanf("%d",&s);
            p=p->ch[1]=getnode(s,p);
        }
        select(pos+1,nill);
        select(pos+2,root);
        root->ch[1]->ch[0]=q;
        q->pre=root->ch[1];
        splay(p,nill);
        return 0;
    }
    int insert(int pos,int tot,int *s){
        node *p,*q;
        p=q=getnode(s[0],nill);
        for(int i=1;i<tot;i++){
            p=p->ch[1]=getnode(s[i],p);
        }
        select(pos+1,nill);
        select(pos+2,root);
        root->ch[1]->ch[0]=q;
        q->pre=root->ch[1];
        splay(p,nill);
        return 0;
    }
    int remove(int pos,int tot){
        select(pos,nill);
        select(pos+tot+1,root);
        if(root->ch[1]->ch[0]!=nill) stk[stksize++]=root->ch[1]->ch[0];
        root->ch[1]->ch[0]=nill;
        splay(root->ch[1],nill);
        return 0;
    }
    int reverse(int pos,int tot){
        select(pos,nill);
        select(pos+tot+1,root);
        root->ch[1]->ch[0]->rev=!root->ch[1]->ch[0]->rev;
        pushdown(root->ch[1]->ch[0]);
        //need pushdown first
        splay(root->ch[1]->ch[0],nill);
        return 0;
    }
    int cover(int pos,int tot,int c){
        select(pos,nill);
        select(pos+tot+1,root);
        root->ch[1]->ch[0]->cover=true;
        root->ch[1]->ch[0]->value=c;
        pushdown(root->ch[1]->ch[0]);
        //need pushdown first
        splay(root->ch[1]->ch[0],nill);

        return 0;
    }
    int getsums(int pos,int tot){
        select(pos,nill);
        select(pos+tot+1,root);
        //pushdown(root->ch[1]->ch[0]);
        //update(root->ch[1]->ch[0]);
        return root->ch[1]->ch[0]->sums;
    }
    int maxsum(){
        //select(1,nill);
        //select(root->size,root);
        pushdown(root);
        update(root);
        return root->maxs;
    }
区间操作的前提都是一样的,首先通过旋转操作,
将目标区间转移到root->ch[1]->ch[0]位置。

插入操作很简单,先建立一条链,挂到root->ch[1]->ch[0]位置,
然后将链末尾节点splay到root。
注意到:

  1. 提取区间的两次select操作保证了root->ch[1]->ch[0]这条链是update过的,因此splay之前不需要pushdown操作
  2. 同时,新建立的链没有标记信息,也不需要pushdown操作
  3. 这样,只需要对链末尾节点执行一次splay(x,root),就会在链末节点不断上升的过程中,一次更新整条链的所有节点(包括了root和root->ch[0])
  4. 这里3代表了前面的第7点,即对于select方法之外的splay调用而言,其内部的pushdown仍然是可以节省的
  5. 对于cover和reverse操作,因为select提取区间之后,执行了修改操作,splay时必须先pushdown,对应前面第8点。

到这里全部说完了这道题目的关键点。
可以看出,select,insert,splay,rotateto基本是通用的模板,关键在于update和pushdown操作的设计要根据具体的题目。

全部代码:

#define __LOCAL__DEBUG__

#include <cstdlib>
#include <cctype>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <algorithm>
#include <vector>
#include <string>
#include <iostream>
#include <sstream>
#include <map>
#include <set>
#include <queue>
#include <stack>
#include <fstream>
#include <numeric>
#include <iomanip>
#include <bitset>
#include <list>
#include <stdexcept>
#include <functional>
#include <utility>
#include <ctime>
using namespace std;

#define PB push_back
#define MP make_pair

#define REP(i,n) for(i=0;i<(n);++i)
#define FOR(i,l,h) for(i=(l);i<=(h);++i)
#define FORD(i,h,l) for(i=(h);i>=(l);--i)

#define LEFT    0
#define RIGHT   1
typedef vector<int> VI;
typedef vector<string> VS;
typedef vector<double> VD;
typedef long long LL64;
typedef unsigned long long LL65;
typedef pair<int,int> PII;
#define MAXN 510000
#define INF 1001

typedef struct splaytree{
    typedef struct node{
        struct node *pre,*ch[2];
        int size,value,sums,maxs,maxl,maxr;
        bool rev,cover;
    }node;
    node *root,*nill,buf[MAXN],*stk[MAXN];
    int bufsize,stksize;
    node *getnode(int value,node *fa){
        node *p=NULL;
        if(bufsize<MAXN){
            p=&buf[bufsize++];
        }else if(stksize>0){
            p=stk[--stksize];
            if(p->ch[0]!=nill) stk[stksize++]=p->ch[0];
            if(p->ch[1]!=nill) stk[stksize++]=p->ch[1];
        }else {
            p=(node*)malloc(sizeof(node));
        } p->size=1;
        p->value=p->sums=p->maxs=p->maxl=p->maxr=value;
        p->rev=p->cover=false; p->pre=fa; p->ch[0]=p->ch[1]=nill;
        return p;
    }
    int pushdown(node *x){
        if(x==nill || x==NULL) return 0;
        if(x->rev){
            x->rev=false;
            x->ch[0]->rev = !x->ch[0]->rev;
            x->ch[1]->rev = !x->ch[1]->rev;
            node *tmp=x->ch[0]; x->ch[0]=x->ch[1]; x->ch[1]=tmp;
            int t=x->maxl; x->maxl=x->maxr; x->maxr=t;
        }
        if(x->cover){
            x->cover=false;
            x->ch[0]->cover= x->ch[1]->cover=true;
            x->ch[0]->value= x->ch[1]->value=x->value;
            x->maxl=x->maxr=x->maxs=x->sums=x->value*x->size;
            if(x->value<0) x->maxl=x->maxr=x->maxs=x->value;
        }
        return 0;
    }
    int update(node *x){
        if(x==nill || x==NULL) return 0;
        pushdown(x->ch[0]); pushdown(x->ch[1]); //must pushdown when you visited it
        x->size = x->ch[0]->size + x->ch[1]->size + 1;
        x->sums = x->ch[0]->sums + x->ch[1]->sums + x->value;
        x->maxl = max(x->ch[0]->maxl,x->ch[0]->sums+x->value);
        x->maxl = max(x->maxl,x->ch[0]->sums+x->value+x->ch[1]->maxl);
        x->maxl = max(x->maxl,x->ch[0]->sums+x->value+x->ch[1]->sums);
        x->maxr = max(x->ch[1]->maxr,x->ch[1]->sums+x->value);
        x->maxr = max(x->maxr,x->ch[1]->sums+x->value+x->ch[0]->maxr);
        x->maxr = max(x->maxr,x->ch[1]->sums+x->value+x->ch[0]->sums);
        x->maxs = max(x->value, x->ch[0]->maxs);
        x->maxs = max(x->maxs, x->ch[1]->maxs);
        x->maxs = max(x->maxs, x->ch[0]->maxr+x->value);
        x->maxs = max(x->maxs, x->ch[1]->maxl+x->value);
        x->maxs = max(x->maxs, x->ch[0]->maxr + x->value + x->ch[1]->maxl);
        return 0;
    }
    int init(){
        bufsize=stksize=0;
        nill=getnode(-INF,NULL);
        nill->ch[0]=nill->ch[1]=nill->pre=NULL;
        nill->size=nill->sums=0;
        root=getnode(-INF,nill);
        root->ch[1]=getnode(-INF,root);
        update(root);
        return 0;
    }
    int rotateto(node *x,int to){
        node *xp=x->pre;
        //this is not must needed : because
        //pushdown(xp); pushdown(x);
        xp->ch[to^1] = x->ch[to];
        xp->ch[to^1]->pre = xp;
        x->pre = xp->pre;
        if(xp->pre->ch[0] == xp) xp->pre->ch[0]=x;
        else xp->pre->ch[1]=x;
        xp->pre=x; x->ch[to]=xp;
        update(xp);
        if(xp == root) root=x;
        return 0;
    }
    int splay(node *x,node *fa){
        pushdown(x);//is this must needed ?
        while(x->pre!=fa){
            if(x->pre->pre == fa){
                if(x->pre->ch[0] == x) rotateto(x,1);
                else rotateto(x,0);
            }else{
                if(x->pre->pre->ch[0]==x->pre){
                    if(x->pre->ch[0] == x) {rotateto(x->pre,1);rotateto(x,1);}
                    else {rotateto(x,0);rotateto(x,1);}
                }else{
                    if(x->pre->ch[1] == x) {rotateto(x->pre,0);rotateto(x,0);}
                    else {rotateto(x,1);rotateto(x,0);}
                }
            }
        }
        update(x);//update(x->pre) is not needed
        return 0;
    }
    int select(int pos,node *fa){
        node *rt=root;
        while(rt!=nill){
            pushdown(rt);
            if(rt->ch[0]->size+1 == pos) break;
            else if(rt->ch[0]->size >= pos) rt=rt->ch[0];
            else {pos-=rt->ch[0]->size+1; rt=rt->ch[1];}
        }
        splay(rt,fa);
        return 0;
    }
    int insert(int pos,int tot){
        node *p,*q;
        int s;
        scanf("%d",&s);
        p=q=getnode(s,nill);
        for(int i=1;i<tot;i++){
            scanf("%d",&s);
            p=p->ch[1]=getnode(s,p);
        }
        select(pos+1,nill);
        select(pos+2,root);
        root->ch[1]->ch[0]=q;
        q->pre=root->ch[1];
        splay(p,nill);
        return 0;
    }
    int insert(int pos,int tot,int *s){
        node *p,*q;
        p=q=getnode(s[0],nill);
        for(int i=1;i<tot;i++){
            p=p->ch[1]=getnode(s[i],p);
        }
        select(pos+1,nill);
        select(pos+2,root);
        root->ch[1]->ch[0]=q;
        q->pre=root->ch[1];
        splay(p,nill);
        return 0;
    }
    int remove(int pos,int tot){
        select(pos,nill);
        select(pos+tot+1,root);
        if(root->ch[1]->ch[0]!=nill) stk[stksize++]=root->ch[1]->ch[0];
        root->ch[1]->ch[0]=nill;
        splay(root->ch[1],nill);
        return 0;
    }
    int reverse(int pos,int tot){
        select(pos,nill);
        select(pos+tot+1,root);
        root->ch[1]->ch[0]->rev=!root->ch[1]->ch[0]->rev;
        splay(root->ch[1]->ch[0],nill);
        return 0;
    }
    int cover(int pos,int tot,int c){
        select(pos,nill);
        select(pos+tot+1,root);
        root->ch[1]->ch[0]->cover=true;
        root->ch[1]->ch[0]->value=c;
        splay(root->ch[1]->ch[0],nill);
        return 0;
    }
    int getsums(int pos,int tot){
        select(pos,nill);
        select(pos+tot+1,root);
        //pushdown(root->ch[1]->ch[0]);
        //update(root->ch[1]->ch[0]);
        return root->ch[1]->ch[0]->sums;
    }
    int maxsum(){
        //select(1,nill);
        //select(root->size,root);
        pushdown(root);
        update(root);
        return root->maxs;
    }
}splaytree;

int main(){
#ifndef __LOCAL__DEBUG__
    std::ios::sync_with_stdio(false);
    std::cin.tie(0);
#else
    freopen("sequence3.in","r",stdin);
    freopen("out.txt","w",stdout);
#endif
    int i,n,m,c;
    int pos,tot;
    splaytree *stree=new splaytree;
    stree->init();
    char s[30];

    scanf("%d%d",&n,&m);
    stree->insert(0,n);
    for(i=0;i<m;i++){
        scanf("%s",s);
        if(s[0]=='I'){
            scanf("%d%d",&pos,&tot);
            stree->insert(pos,tot);
        }else if(s[0]=='D'){
            scanf("%d%d",&pos,&tot);
            stree->remove(pos,tot);
        }else if(s[0]=='R'){
            scanf("%d%d",&pos,&tot);
            stree->reverse(pos,tot);
        }else if(s[0]=='G'){
            scanf("%d%d",&pos,&tot);
            printf("%d\n",stree->getsums(pos,tot));
        }else if(s[2]=='K'){
            scanf("%d%d%d",&pos,&tot,&c);
            stree->cover(pos,tot,c);
        }else {
            printf("%d\n",stree->maxsum());
        }
    }
    delete stree;
    return 0;
}

[1]:http://www.lydsy.com/JudgeOnline/problem.php?id=1500


扫码关注作者,定期分享技术、算法类文章
这里写图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值