前言
Splay是一种维护平衡二叉树的算法。虽然它常数大,而且比较难打,但Splay十分方便,而且LCT需要用到。
约定
c
n
t
i
cnt_i
cnti:节点
i
i
i的个数
v
a
l
i
val_i
vali:节点
i
i
i的权值
s
i
z
i
siz_i
sizi:节点
i
i
i的子树大小
c
h
i
,
0
/
1
ch_{i,0/1}
chi,0/1:节点
i
i
i的左右儿子
f
a
i
fa_i
fai:节点
i
i
i的父亲
r
o
o
t
root
root:当前的根节点
t
o
t
tot
tot:当前的节点数量
g
t
(
x
)
gt(x)
gt(x):返回
x
x
x是左儿子还是右儿子
p
u
s
h
u
p
(
x
)
pushup(x)
pushup(x):更新当前子树大小
int gt(int x){
return ch[fa[x]][1]==x;
}
void pt(int x){
siz[x]=cnt[x]+siz[ch[x][0]]+siz[ch[x][1]];
}
基本操作
旋转操作
- 原来 y y y是 z z z的哪个儿子, x x x就作为 z z z的哪个儿子
- 原来 x x x是 y y y的哪个儿子, x x x的对应儿子的兄弟就作为 y y y的哪个儿子
- 原来 x x x是 y y y的哪个儿子, y y y就作为 x x x的对应儿子的兄弟
void rot(int x){
int y=fa[x],z=fa[y],k=gt(x);
ch[z][gt(y)]=x,fa[x]=z;
ch[y][k]=ch[x][!k],fa[ch[y][k]]=y;
ch[x][!k]=y,fa[y]=x;
pt(y);pt(x);
}
伸展操作
s p l a y ( x , g ) splay(x,g) splay(x,g),表示将 x x x旋转到 g g g下面。
我们可以一直 r o t rot rot,但如果 x x x的父亲不是 g g g且 x x x和 x x x的父亲是同一边的儿子,则可以旋转父亲。先旋转父亲可以减少深度。
void splay(int x,int g=0){
for(int y;fa[x]!=g;rot(x)){
y=fa[x];
if(fa[y]!=g) rot((gt(x)==gt(y))?y:x);
}
if(!g) root=x;
}
普通操作
find操作
找到值最接近 x x x的点,并伸展到根。
void find(int x){
if(!root) return;
int u=root;
while(ch[u][x>v[u]]&&x!=v[u]) u=ch[u][x>v[u]];
splay(u);
}
insert操作
插入值为 x x x的点,需进行一下操作
- 找到插入点的位置
- 如果存在值为 x x x的点,则加对应的 c n t cnt cnt
- 否则新加一个点
- 把该节点伸展到根
void insert(int x){
int u=root,fu=0;
while(u&&v[u]!=x){
fu=u,u=ch[u][x>v[u]];
}
if(u) ++cnt[u];
else{
u=++tot;
if(fu) ch[fu][x>v[fu]]=u;
fa[u]=fu;v[u]=x;
cnt[u]=siz[u]=1;
}
splay(u);
}
前驱和后继
求 x x x的前驱和后继
先 f i n d ( x ) find(x) find(x),那么前驱就是左子树中最大的一个,后继就是右子树中最小的一个。
int nxt(int x,int f){
find(x);
int u=root;
if(v[u]>x&&f) return u;
if(v[u]<x&&!f) return u;
u=ch[u][f];
while(ch[u][!f]) u=ch[u][!f];
return u;
}
delete操作
删除值为 x x x的点
首先找到 x x x的前驱 s u c suc suc和 x x x的后继 p r e pre pre,然后
- s p l a y ( p r e ) splay(pre) splay(pre)
- s p l a y ( s u c , p r e ) splay(suc,pre) splay(suc,pre)
然后 s u c suc suc的左子树就是要删除的点,删除即可。
void dele(int x){
int lt=nxt(x,0),nt=nxt(x,1);
splay(lt);splay(nt,lt);
int tx=ch[nt][0];
if(cnt[tx]>1) --cnt[tx],splay(tx);
else ch[nt][0]=0;
}
kth操作
找到排名为 k k k的节点的权值
int kth(int k){
int u=root,sn;
for(;;){
sn=ch[u][0];
if(k>siz[sn]+cnt[u]) k-=siz[sn]+cnt[u],u=ch[u][1];
else if(siz[sn]>=k) u=sn;
else return v[u];
}
}
例题
对于操作1,2,4,5,6,可以用上述操作解决即可。对于操作3,可以 f i n d ( x ) find(x) find(x)将其置为根,然后 x x x的排名就是它左子树的节点个数 + 1 +1 +1。
注意为了防止 s p l a y splay splay出锅,要在加上两个节点 + ∞ +\infty +∞和 − ∞ -\infty −∞。注意这两个节点对操作的影响。
#include<iostream>
#include<cstdio>
#define N 500000
using namespace std;
int t,rt,tot,cnt[N],v[N],siz[N],fa[N],ch[N][2];
int gt(int x){
return ch[fa[x]][1]==x;
}//Return 0 or 1 means x is the left or right son
void pt(int x){
siz[x]=cnt[x]+siz[ch[x][0]]+siz[ch[x][1]];
}//Update the x
void rot(int x){
int y=fa[x],z=fa[y],k=gt(x);
ch[z][gt(y)]=x,fa[x]=z;
ch[y][k]=ch[x][!k],fa[ch[y][k]]=y;
ch[x][!k]=y,fa[y]=x;
pt(y);pt(x);
}//Rotate
void splay(int x,int g=0){
for(int y;fa[x]!=g;rot(x)){
y=fa[x];
if(fa[y]!=g) rot((gt(x)==gt(y))?y:x);
}
if(!g) rt=x;
}//Put the x under the g
void find(int x){
if(!rt) return;
int u=rt;
while(ch[u][x>v[u]]&&x!=v[u]) u=ch[u][x>v[u]];
splay(u);
}//Find the closest node and put it under the root
void insert(int x){
int u=rt,fu=0;
while(u&&v[u]!=x){
fu=u,u=ch[u][x>v[u]];
}
if(u) ++cnt[u];
else{
u=++tot;
if(fu) ch[fu][x>v[fu]]=u;
fa[u]=fu;v[u]=x;
cnt[u]=siz[u]=1;
}
splay(u);
}//Insert the x
//1.Find the root which should be inserted
//2.If there is a node as same as it,plus its cnt
//3.Else plus a node
//4.Make the new node be the root
int nxt(int x,int f){
find(x);
int u=rt;
if(v[u]>x&&f) return u;
if(v[u]<x&&!f) return u;
u=ch[u][f];
while(ch[u][!f]) u=ch[u][!f];
return u;
}//Find the suc or the pre of the x
//After finding x,then
//The pre is the biggest one in left tree
//The suc is the smallest one in right tree
void dele(int x){
int lt=nxt(x,0),nt=nxt(x,1);
splay(lt);splay(nt,lt);
int tx=ch[nt][0];
if(cnt[tx]>1) --cnt[tx],splay(tx);
else ch[nt][0]=0;
}//Find the pre and the suc of the x
//Splay(pre),splay(suc,pre)
//Then delete the left son of the suc
int kth(int k){
int u=rt,sn;
for(;;){
sn=ch[u][0];
if(k>siz[sn]+cnt[u]) k-=siz[sn]+cnt[u],u=ch[u][1];
else if(siz[sn]>=k) u=sn;
else return v[u];
}
}
int main()
{
insert(2147483647);
insert(-2147483647);
scanf("%d",&t);
while(t--){
int op,x;
scanf("%d%d",&op,&x);
if(op==1) insert(x);
else if(op==2) dele(x);
else if(op==3) find(x),printf("%d\n",siz[ch[rt][0]]+(v[rt]<x)*cnt[rt]);
else if(op==4) printf("%d\n",kth(x+1));
else if(op==5) printf("%d\n",v[nxt(x,0)]);
else printf("%d\n",v[nxt(x,1)]);
}
return 0;
}