以 [NOI 2005 维护序列][1]为例,说明一下具体的splay tree 的代码编写和实现技巧。
翻转真的是一个BT的操作~~
一步一步来,首先,结构体定义:
typedef struct node{
struct node *pre,*ch[2];
int size,value,sums,maxs,maxl,maxr;
bool rev,cover;
}node;
node *root,*nill,buf[MAXN],*stk[MAXN];
int bufsize,stksize;
root指向树根节点 nill指向所有的空叶子节点
buf和stk则是用来进行节点的预分配和回收,原题目只给了64MB内存,不回收会MLE。
sums代表子区间的和,maxs代表该子区间的最大子段和,
maxl代表该子区间左起最大的累加和
maxr代表该子区间右起最大的累加和
新建空树,初始化过程如下:
node *getnode(int value,node *fa){
node *p=NULL;
if(bufsize<MAXN){
p=&buf[bufsize++];
}else if(stksize>0){
p=stk[--stksize];
if(p->ch[0]!=nill) stk[stksize++]=p->ch[0];
if(p->ch[1]!=nill) stk[stksize++]=p->ch[1];
}else {
p=(node*)malloc(sizeof(node));
} p->size=1;
p->value=p->sums=p->maxs=p->maxl=p->maxr=value;
p->rev=p->cover=false; p->pre=fa; p->ch[0]=p->ch[1]=nill;
return p;
}
int init(){
bufsize=stksize=0;
nill=getnode(-INF,NULL);
nill->ch[0]=nill->ch[1]=nill->pre=NULL;
nill->size=nill->sums=0;
root=getnode(-INF,nill);
root->ch[1]=getnode(-INF,root);
update(root);
return 0;
}
这里实现涉及了两个小技巧:
增加nill节点,不对nill进行update操作,将其sums和size置0,可以省去边界叶子节点判断的麻烦。
同时,为整个splay树增加了两个边界节点,value设为负无穷,方便区间边界的处理。
查找第POS位置的节点,并提取到根:
int select(int pos,node *fa){
node *rt=root;
while(rt!=nill){
pushdown(rt);
if(rt->ch[0]->size+1 == pos) break;
else if(rt->ch[0]->size >= pos) rt=rt->ch[0];
else {pos-=rt->ch[0]->size+1; rt=rt->ch[1];}
}
splay(rt,fa);
return 0;
}
这里,size域起到了作用,代表了该区间所有节点的总数,
注意到附加的全局叶子结点nill->size域为0,因此,nill不会影响到size域的计数。
而两个边界节点的size初始为1,是参与计数的。
查找方法很简单,判断当前节点左边孩子数量是否等于pos-1,若是此时的当前节点就是我们要找的,否则若小于pos-1,则往左走,再否则往右走。
同时每次访问前都要pushdown操作时必须的。
Splay(rt,fa),将rt伸展到指定的 祖先 fa下面:
int pushdown(node *x){
if(x==nill || x==NULL) return 0;
if(x->rev){
x->rev=false;
x->ch[0]->rev = !x->ch[0]->rev;
x->ch[1]->rev = !x->ch[1]->rev;
node *tmp=x->ch[0]; x->ch[0]=x->ch[1]; x->ch[1]=tmp;
int t=x->maxl; x->maxl=x->maxr; x->maxr=t;
}
if(x->cover){
x->cover=false;
x->ch[0]->cover= x->ch[1]->cover=true;
x->ch[0]->value= x->ch[1]->value=x->value;
x->maxl=x->maxr=x->maxs=x->sums=x->value*x->size;
if(x->value<0) x->maxl=x->maxr=x->maxs=x->value;
}
return 0;
}
int update(node *x){
if(x==nill || x==NULL) return 0;
pushdown(x->ch[0]); pushdown(x->ch[1]); //must pushdown when you visited it
x->size = x->ch[0]->size + x->ch[1]->size + 1;
x->sums = x->ch[0]->sums + x->ch[1]->sums + x->value;
x->maxl = max(x->ch[0]->maxl,x->ch[0]->sums+x->value);
x->maxl = max(x->maxl,x->ch[0]->sums+x->value+x->ch[1]->maxl);
x->maxl = max(x->maxl,x->ch[0]->sums+x->value+x->ch[1]->sums);
x->maxr = max(x->ch[1]->maxr,x->ch[1]->sums+x->value);
x->maxr = max(x->maxr,x->ch[1]->sums+x->value+x->ch[0]->maxr);
x->maxr = max(x->maxr,x->ch[1]->sums+x->value+x->ch[0]->sums);
x->maxs = max(x->value, x->ch[0]->maxs);
x->maxs = max(x->maxs, x->ch[1]->maxs);
x->maxs = max(x->maxs, x->ch[0]->maxr+x->value);
x->maxs = max(x->maxs, x->ch[1]->maxl+x->value);
x->maxs = max(x->maxs, x->ch[0]->maxr + x->value + x->ch[1]->maxl);
return 0;
}
int rotateto(node *x,int to){
node *xp=x->pre;
//
//pushdown(xp); pushdown(x);
xp->ch[to^1] = x->ch[to];
xp->ch[to^1]->pre = xp;
x->pre = xp->pre;
if(xp->pre->ch[0] == xp) xp->pre->ch[0]=x;
else xp->pre->ch[1]=x;
xp->pre=x; x->ch[to]=xp;
update(xp);//update(x) is not needed
if(xp == root) root=x;
return 0;
}
int splay(node *x,node *fa){
pushdown(x);//this is must needed,note : cover && reverse
while(x->pre!=fa){
if(x->pre->pre == fa){
if(x->pre->ch[0] == x) rotateto(x,1);
else rotateto(x,0);
}else{
if(x->pre->pre->ch[0]==x->pre){
if(x->pre->ch[0] == x) {rotateto(x->pre,1);rotateto(x,1);}
else {rotateto(x,0);rotateto(x,1);}
}else{
if(x->pre->ch[1] == x) {rotateto(x->pre,0);rotateto(x,0);}
else {rotateto(x,1);rotateto(x,0);}
}
}
}
update(x);
return 0;
}
其中pushdown和update操作的功能是明确区分开的:
- pushdown负责将当前节点的标记落实,并传递给孩子节点,最后清除自己的标记信息。
- update负责根据左右孩子节点(一定是从左右孩子处获取最新信息,因为旋转过程会使得左右孩子发生变化)更新自己。
这里就是splay的核心操作,旋转操作将x节点不断旋转到祖先fa下面,
注意到:
- 每次旋转都会将X提高一层,x的父节点xp随之下降一层
- 每次提升x,必须update(xp),因为其左右孩子发生了变化。
- update(x)可以在所有rotateto之后执行
- splay方法结束时,无需update(fa)
- 每次select操作都会从root到x的路径上执行pushdown,所以select内部的splay、rotateto操作中pushdown可以节省掉
- 同时splay则对应了一条upadate路径,select的最后调用了splay(),这意味着select结束时,从root出发的某一条路径是没有lazy
tag的。- 这意味着,当我们select之后,再次访问(只读取,不修改)这条路径上的某个节点时,无需再次进行pushdown。
- 但是,对于7,如果我们不仅访问,并且修改了这条路径上的某个节点,此时splay该节点,则必须进行pushdown。
其余的操作,插入,更新,删除,等:
int insert(int pos,int tot){
node *p,*q;
int s;
scanf("%d",&s);
p=q=getnode(s,nill);
for(int i=1;i<tot;i++){
scanf("%d",&s);
p=p->ch[1]=getnode(s,p);
}
select(pos+1,nill);
select(pos+2,root);
root->ch[1]->ch[0]=q;
q->pre=root->ch[1];
splay(p,nill);
return 0;
}
int insert(int pos,int tot,int *s){
node *p,*q;
p=q=getnode(s[0],nill);
for(int i=1;i<tot;i++){
p=p->ch[1]=getnode(s[i],p);
}
select(pos+1,nill);
select(pos+2,root);
root->ch[1]->ch[0]=q;
q->pre=root->ch[1];
splay(p,nill);
return 0;
}
int remove(int pos,int tot){
select(pos,nill);
select(pos+tot+1,root);
if(root->ch[1]->ch[0]!=nill) stk[stksize++]=root->ch[1]->ch[0];
root->ch[1]->ch[0]=nill;
splay(root->ch[1],nill);
return 0;
}
int reverse(int pos,int tot){
select(pos,nill);
select(pos+tot+1,root);
root->ch[1]->ch[0]->rev=!root->ch[1]->ch[0]->rev;
pushdown(root->ch[1]->ch[0]);
//need pushdown first
splay(root->ch[1]->ch[0],nill);
return 0;
}
int cover(int pos,int tot,int c){
select(pos,nill);
select(pos+tot+1,root);
root->ch[1]->ch[0]->cover=true;
root->ch[1]->ch[0]->value=c;
pushdown(root->ch[1]->ch[0]);
//need pushdown first
splay(root->ch[1]->ch[0],nill);
return 0;
}
int getsums(int pos,int tot){
select(pos,nill);
select(pos+tot+1,root);
//pushdown(root->ch[1]->ch[0]);
//update(root->ch[1]->ch[0]);
return root->ch[1]->ch[0]->sums;
}
int maxsum(){
//select(1,nill);
//select(root->size,root);
pushdown(root);
update(root);
return root->maxs;
}
区间操作的前提都是一样的,首先通过旋转操作,
将目标区间转移到root->ch[1]->ch[0]位置。
插入操作很简单,先建立一条链,挂到root->ch[1]->ch[0]位置,
然后将链末尾节点splay到root。
注意到:
- 提取区间的两次select操作保证了root->ch[1]->ch[0]这条链是update过的,因此splay之前不需要pushdown操作
- 同时,新建立的链没有标记信息,也不需要pushdown操作
- 这样,只需要对链末尾节点执行一次splay(x,root),就会在链末节点不断上升的过程中,一次更新整条链的所有节点(包括了root和root->ch[0])
- 这里3代表了前面的第7点,即对于select方法之外的splay调用而言,其内部的pushdown仍然是可以节省的
- 对于cover和reverse操作,因为select提取区间之后,执行了修改操作,splay时必须先pushdown,对应前面第8点。
到这里全部说完了这道题目的关键点。
可以看出,select,insert,splay,rotateto基本是通用的模板,关键在于update和pushdown操作的设计要根据具体的题目。
全部代码:
#define __LOCAL__DEBUG__
#include <cstdlib>
#include <cctype>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <algorithm>
#include <vector>
#include <string>
#include <iostream>
#include <sstream>
#include <map>
#include <set>
#include <queue>
#include <stack>
#include <fstream>
#include <numeric>
#include <iomanip>
#include <bitset>
#include <list>
#include <stdexcept>
#include <functional>
#include <utility>
#include <ctime>
using namespace std;
#define PB push_back
#define MP make_pair
#define REP(i,n) for(i=0;i<(n);++i)
#define FOR(i,l,h) for(i=(l);i<=(h);++i)
#define FORD(i,h,l) for(i=(h);i>=(l);--i)
#define LEFT 0
#define RIGHT 1
typedef vector<int> VI;
typedef vector<string> VS;
typedef vector<double> VD;
typedef long long LL64;
typedef unsigned long long LL65;
typedef pair<int,int> PII;
#define MAXN 510000
#define INF 1001
typedef struct splaytree{
typedef struct node{
struct node *pre,*ch[2];
int size,value,sums,maxs,maxl,maxr;
bool rev,cover;
}node;
node *root,*nill,buf[MAXN],*stk[MAXN];
int bufsize,stksize;
node *getnode(int value,node *fa){
node *p=NULL;
if(bufsize<MAXN){
p=&buf[bufsize++];
}else if(stksize>0){
p=stk[--stksize];
if(p->ch[0]!=nill) stk[stksize++]=p->ch[0];
if(p->ch[1]!=nill) stk[stksize++]=p->ch[1];
}else {
p=(node*)malloc(sizeof(node));
} p->size=1;
p->value=p->sums=p->maxs=p->maxl=p->maxr=value;
p->rev=p->cover=false; p->pre=fa; p->ch[0]=p->ch[1]=nill;
return p;
}
int pushdown(node *x){
if(x==nill || x==NULL) return 0;
if(x->rev){
x->rev=false;
x->ch[0]->rev = !x->ch[0]->rev;
x->ch[1]->rev = !x->ch[1]->rev;
node *tmp=x->ch[0]; x->ch[0]=x->ch[1]; x->ch[1]=tmp;
int t=x->maxl; x->maxl=x->maxr; x->maxr=t;
}
if(x->cover){
x->cover=false;
x->ch[0]->cover= x->ch[1]->cover=true;
x->ch[0]->value= x->ch[1]->value=x->value;
x->maxl=x->maxr=x->maxs=x->sums=x->value*x->size;
if(x->value<0) x->maxl=x->maxr=x->maxs=x->value;
}
return 0;
}
int update(node *x){
if(x==nill || x==NULL) return 0;
pushdown(x->ch[0]); pushdown(x->ch[1]); //must pushdown when you visited it
x->size = x->ch[0]->size + x->ch[1]->size + 1;
x->sums = x->ch[0]->sums + x->ch[1]->sums + x->value;
x->maxl = max(x->ch[0]->maxl,x->ch[0]->sums+x->value);
x->maxl = max(x->maxl,x->ch[0]->sums+x->value+x->ch[1]->maxl);
x->maxl = max(x->maxl,x->ch[0]->sums+x->value+x->ch[1]->sums);
x->maxr = max(x->ch[1]->maxr,x->ch[1]->sums+x->value);
x->maxr = max(x->maxr,x->ch[1]->sums+x->value+x->ch[0]->maxr);
x->maxr = max(x->maxr,x->ch[1]->sums+x->value+x->ch[0]->sums);
x->maxs = max(x->value, x->ch[0]->maxs);
x->maxs = max(x->maxs, x->ch[1]->maxs);
x->maxs = max(x->maxs, x->ch[0]->maxr+x->value);
x->maxs = max(x->maxs, x->ch[1]->maxl+x->value);
x->maxs = max(x->maxs, x->ch[0]->maxr + x->value + x->ch[1]->maxl);
return 0;
}
int init(){
bufsize=stksize=0;
nill=getnode(-INF,NULL);
nill->ch[0]=nill->ch[1]=nill->pre=NULL;
nill->size=nill->sums=0;
root=getnode(-INF,nill);
root->ch[1]=getnode(-INF,root);
update(root);
return 0;
}
int rotateto(node *x,int to){
node *xp=x->pre;
//this is not must needed : because
//pushdown(xp); pushdown(x);
xp->ch[to^1] = x->ch[to];
xp->ch[to^1]->pre = xp;
x->pre = xp->pre;
if(xp->pre->ch[0] == xp) xp->pre->ch[0]=x;
else xp->pre->ch[1]=x;
xp->pre=x; x->ch[to]=xp;
update(xp);
if(xp == root) root=x;
return 0;
}
int splay(node *x,node *fa){
pushdown(x);//is this must needed ?
while(x->pre!=fa){
if(x->pre->pre == fa){
if(x->pre->ch[0] == x) rotateto(x,1);
else rotateto(x,0);
}else{
if(x->pre->pre->ch[0]==x->pre){
if(x->pre->ch[0] == x) {rotateto(x->pre,1);rotateto(x,1);}
else {rotateto(x,0);rotateto(x,1);}
}else{
if(x->pre->ch[1] == x) {rotateto(x->pre,0);rotateto(x,0);}
else {rotateto(x,1);rotateto(x,0);}
}
}
}
update(x);//update(x->pre) is not needed
return 0;
}
int select(int pos,node *fa){
node *rt=root;
while(rt!=nill){
pushdown(rt);
if(rt->ch[0]->size+1 == pos) break;
else if(rt->ch[0]->size >= pos) rt=rt->ch[0];
else {pos-=rt->ch[0]->size+1; rt=rt->ch[1];}
}
splay(rt,fa);
return 0;
}
int insert(int pos,int tot){
node *p,*q;
int s;
scanf("%d",&s);
p=q=getnode(s,nill);
for(int i=1;i<tot;i++){
scanf("%d",&s);
p=p->ch[1]=getnode(s,p);
}
select(pos+1,nill);
select(pos+2,root);
root->ch[1]->ch[0]=q;
q->pre=root->ch[1];
splay(p,nill);
return 0;
}
int insert(int pos,int tot,int *s){
node *p,*q;
p=q=getnode(s[0],nill);
for(int i=1;i<tot;i++){
p=p->ch[1]=getnode(s[i],p);
}
select(pos+1,nill);
select(pos+2,root);
root->ch[1]->ch[0]=q;
q->pre=root->ch[1];
splay(p,nill);
return 0;
}
int remove(int pos,int tot){
select(pos,nill);
select(pos+tot+1,root);
if(root->ch[1]->ch[0]!=nill) stk[stksize++]=root->ch[1]->ch[0];
root->ch[1]->ch[0]=nill;
splay(root->ch[1],nill);
return 0;
}
int reverse(int pos,int tot){
select(pos,nill);
select(pos+tot+1,root);
root->ch[1]->ch[0]->rev=!root->ch[1]->ch[0]->rev;
splay(root->ch[1]->ch[0],nill);
return 0;
}
int cover(int pos,int tot,int c){
select(pos,nill);
select(pos+tot+1,root);
root->ch[1]->ch[0]->cover=true;
root->ch[1]->ch[0]->value=c;
splay(root->ch[1]->ch[0],nill);
return 0;
}
int getsums(int pos,int tot){
select(pos,nill);
select(pos+tot+1,root);
//pushdown(root->ch[1]->ch[0]);
//update(root->ch[1]->ch[0]);
return root->ch[1]->ch[0]->sums;
}
int maxsum(){
//select(1,nill);
//select(root->size,root);
pushdown(root);
update(root);
return root->maxs;
}
}splaytree;
int main(){
#ifndef __LOCAL__DEBUG__
std::ios::sync_with_stdio(false);
std::cin.tie(0);
#else
freopen("sequence3.in","r",stdin);
freopen("out.txt","w",stdout);
#endif
int i,n,m,c;
int pos,tot;
splaytree *stree=new splaytree;
stree->init();
char s[30];
scanf("%d%d",&n,&m);
stree->insert(0,n);
for(i=0;i<m;i++){
scanf("%s",s);
if(s[0]=='I'){
scanf("%d%d",&pos,&tot);
stree->insert(pos,tot);
}else if(s[0]=='D'){
scanf("%d%d",&pos,&tot);
stree->remove(pos,tot);
}else if(s[0]=='R'){
scanf("%d%d",&pos,&tot);
stree->reverse(pos,tot);
}else if(s[0]=='G'){
scanf("%d%d",&pos,&tot);
printf("%d\n",stree->getsums(pos,tot));
}else if(s[2]=='K'){
scanf("%d%d%d",&pos,&tot,&c);
stree->cover(pos,tot,c);
}else {
printf("%d\n",stree->maxsum());
}
}
delete stree;
return 0;
}
[1]:http://www.lydsy.com/JudgeOnline/problem.php?id=1500
扫码关注作者,定期分享技术、算法类文章