Splay中有一些常用的操作和一些很容易犯的bug
先来记录一下容易写错的地方
- 新建结点(或者连续一段,特指在key_value的位置插入和删除)或者删除结点(或者连续一段)的时候需要pushup(ch[root][1]),pushup(root)
- 申请一个新的结点的时候注意更新与这个点所有相关的变量,比如是维护最小值,不仅要更新val[now],还要更新minv[now]
- 如果有区间加的时候,你需要查询Splay上某个点的值,你不能直接输出val[now],应该先把这个点旋转到根
- 如果需要知道某个值在Splay中的位置,可以用一个map映射
- 调用get_kth,get_min,get_max这些函数时,要把调用的这个点旋转到根
- 插入的时候要进行旋转,这样才能保证时间复杂度logn
接下来是一些常用的函数
1.Splay的构造和初始化
void build(int &now,int l,int r,int father){
if(l>r)
return ;
int m=(l+r)>>1;
Newnode(now,father,a[m]);
build(ch[now][0],l,m-1,now);
build(ch[now][1],m+1,r,now);
pushup(now);
}
void init(){
root=tot=0;
Newnode(root,0,-INF);
Newnode(ch[root][1],root,-INF);//头尾各加入一个空位
build(key_value,1,n,ch[root][1]);
pushup(ch[root][1]);
pushup(root);
}
2.pushup和pushdown
//addv表示区间加,rev表示区间翻转
void pushdown(int x){
int ls=ch[x][0],rs=ch[x][1];
if(addv[x]){
if(ls)
val[ls]+=addv[x],addv[ls]+=addv[x];
if(rs)
val[rs]+=addv[x],addv[rs]+=addv[x];
addv[x]=0;
}
if(rev[x]){
if(ls) rev[ls]^=1;
if(rs) rev[rs]^=1;
swap(ch[x][0],ch[x][1]);
rev[x]=0;
}
}
void pushup(int x){
size[x]=size[ch[x][0]]+size[ch[x][1]]+1;
}
3.Splay主体函数
//旋转,kind为1为右旋,kind为0为左旋
void Rotate(int x,int kind){
int y=fa[x];
ch[y][!kind]=ch[x][kind];
fa[ch[x][kind]]=y;
//如果父节点不是根结点,则要和父节点的父节点连接起来
if(fa[y])
ch[fa[y]][ch[fa[y]][1]==y]=x;
fa[x]=fa[y];
ch[x][kind]=y;
fa[y]=x;
pushup(y);
}
//Splay调整,将根为now的子树调整为goal
void Splay(int now,int goal){
pushdown(now);
while(fa[now]!=goal){
pushdown(fa[fa[now]]),pushdown(fa[now]),pushdown(now);
if(fa[ fa[now] ]==goal)
Rotate(now,ch[ fa[now] ][0]==now);
else{
int pre=fa[now],kind=ch[ fa[pre] ][0]==pre; //左儿子为1,右儿子为0
if(ch[pre][kind]==now){ //两个方向不同
Rotate(now,!kind);
Rotate(now,kind);
}
else{ //两个方向相同
Rotate(pre,kind);
Rotate(now,kind);
}
}
}
if(goal==0) root=now;
pushup(now);
}
4.查询后继
//调用:Splay(x,0),get_min(x)
int get_min(int x){
x=ch[x][1];
pushdown(x);
while(ch[x][0])
x=ch[x][0],pushdown(x);
return x;
}
5.查询前驱
//调用:Splay(x,0),get_max(x)
int get_max(int x){
x=ch[x][0];
pushdown(x);
while(ch[x][1])
x=ch[x][1],pushdown(x);
return x;
}
6.查询第k大
//因为刚开始虚拟构造了两个结点,所以如果要查询第k大的话,应该调用get_kth(root,k+1)
int get_kth(int x,int k){
pushdown(x);
int num=size[ch[x][0]]+1;
if(num==k)
return x;
else if(num>k)
return get_kth(ch[x][0],k);
return get_kth(ch[x][1],k-num);
}
7.删除根节点
//需要保证将要删除的点旋转到了根
void remove(){
int m=get_max(root);
Splay(m,root);
ch[m][1]=ch[root][1];
fa[ch[root][1]]=m;
root=m;
fa[root]=0;
pushup(root);
}
8.插入一个位置或值
//x表示在Splay数中的位置,在这个位置之后插入一个数,k表示插入的数的值
void Insert(int x,int k){
Splay(x,0);
Splay(get_min(x),root);
Newnode(key_value,ch[root][1],k);
pushup(ch[root][1]),pushup(root);
}
int Insert(int now,int k){
while(ch[now][val[now]<k]){
//不重复插入
if(val[now]==k){
Splay(now,0);
cnt[now]++;
pushup(now);
return 0;
}
now=ch[now][val[now]<k];
}
if(val[now]==k){
Splay(now,0);
cnt[now]++;
pushup(now);
return 0;
}
Newnode(ch[now][k>val[now]],now,k);
//将新插入的结点更新至根结点
Splay(ch[now][k>val[now]],0);
return 1;
}
//数不重复插入
int Insert(int now,long long k){
while(ch[now][val[now]<k]){
//不重复插入
if(val[now]==k){
int x=ch[now][1];
if(x==0){
Newnode(ch[now][1],now,k);
Splay(ch[now][1],0);
}
else{
Newnode(ch[now][1],now,k);
ch[ ch[now][1] ][1]=x,fa[x]=ch[now][1];
Splay(x,0);
}
return 0;
}
now=ch[now][val[now]<k];
}
if(val[now]==k){
int x=ch[now][1];
if(x==0){
Newnode(ch[now][1],now,k);
Splay(ch[now][1],0);
}
else{
Newnode(ch[now][1],now,k);
ch[ ch[now][1] ][1]=x,fa[x]=ch[now][1];
Splay(x,0);
}
return 0;
}
Newnode(ch[now][k>val[now]],now,k);
//将新插入的结点更新至根结点
Splay(ch[now][k>val[now]],0);
return 1;
}
9.删除一个数
//如果不需要构造容量池的话,erase函数可以删去
//Delete(x)中的x要删去第x个值
void erase(int x){
if(!x)
return ;
s[++tot2]=x;
erase(ch[x][0]);
erase(ch[x][1]);
}
void Delete(int x){
Splay(get_kth(root,x),0);
Splay(get_kth(root,x+2),root);
erase(key_value),key_value=0,pushup(ch[root][1]),pushup(root);
}
10.区间加和区间翻转
void ADD(int l,int r,int k){
Splay(get_kth(root,l),0);
Splay(get_kth(root,r+2),root);
val[key_value]+=k,addv[key_value]+=k;
pushup(ch[root][1]);
pushup(root);
}
void reverse(int l,int r){
Splay(get_kth(root,l),0);
Splay(get_kth(root,r+2),root);
rev[key_value]^=1;
pushup(ch[root][1]);
pushup(root);
}
10.如果区间是呈现环形的,区间加和翻转
//Hdu 4543
//ADD(x,k,distance)表示x之后的distance个都加上k,x表示在Splay中的位置
void Add(int x,int k,int distance){
Splay(x,0);
//size[ch[x][0]]表示这个x的Rank,size[x]-2表示总共有多少个
if(size[ch[x][0]]+distance-1<=size[x]-2){
int tmp=size[ch[x][0]]+distance-1;
Splay(get_max(root),0);
Splay(get_kth(root,tmp+2),root);
val[key_value]+=k,addv[key_value]+=k;
pushup(ch[root][1]),pushup(root);
}
else{
//把前面的切割掉
int remain=size[ch[x][0]]+distance-1-(size[x]-2);//开头有多少个
Splay(1,0),Splay(get_kth(root,remain+2),root);
int tmp=key_value;
key_value=0,pushup(ch[root][1]),pushup(root);
Splay(get_kth(root,size[root]-1),0),Splay(2,root);
key_value=tmp,fa[key_value]=ch[root][1];
pushup(ch[root][1]),pushup(root);
Splay(x,0);
Splay(get_max(x),0),Splay(2,root);
val[key_value]+=k,addv[key_value]+=k;
}
}
//Reverse(x,distance)表示x之后的distance个翻转,x表示在Splay中的位置
void Reverse(int x,int distance){
Splay(x,0);
//size[ch[x][0]]表示这个x的Rank,size[x]-2表示总共有多少个
if(size[ch[x][0]]+distance-1<=size[x]-2){
int tmp=size[ch[x][0]]+distance-1;
Splay(get_max(root),0);
Pointer=get_kth(root,tmp+1);
Splay(get_kth(root,tmp+2),root);
rev[key_value]^=1;
pushup(ch[root][1]),pushup(root);
}
else{
//把前面的切割掉
int remain=size[ch[x][0]]+distance-1-(size[x]-2);//开头有多少个
Splay(1,0),Splay(get_kth(root,remain+2),root);
Pointer=get_kth(root,remain+1);
int tmp=key_value;
key_value=0,pushup(ch[root][1]),pushup(root);
Splay(get_kth(root,size[root]-1),0),Splay(2,root);
key_value=tmp,fa[key_value]=ch[root][1]; //拼接
pushup(ch[root][1]),pushup(root);
Splay(x,0);
Splay(get_max(x),0),Splay(2,root);
rev[key_value]^=1;
}
}
11.区间循环移位
//l-r之间循环移动k位
void revolve(int l,int r,int k){
if(!k) return ;
Splay(get_kth(root,l),0);
Splay(get_kth(root,r-k+2),root);
//将后者取出
int tmp=key_value;
key_value=0,pushup(ch[root][1]),pushup(root);
Splay(get_kth(root,l+k),0);
Splay(get_kth(root,l+k+1),root);
key_value=tmp;
fa[key_value]=ch[root][1],pushup(ch[root][1]),pushup(root);
}
12.节点的新建
//带回收
void Newnode(int &now,int father,long long k){
if(tot2) now=s[tot2--];
else now=++tot1;
val[now]=k,ch[now][0]=ch[now][1]=0,fa[now]=father;
size[now]=1;
}
//不带回收
void Newnode(int &now,int father,long long k){
now=++tot;
val[now]=k,ch[now][0]=ch[now][1]=0,fa[now]=father;
size[now]=1;
}