为了学 link-cut-tree \text{link-cut-tree} link-cut-tree 才讲的 Splay \text{Splay} Splay,之前已经学过无旋 treap \text{treap} treap了,因为本质上都是对二叉搜索树的优化,理解起来可能会更容易吧,下面就以这一道例题:普通平衡树,来讲解一下 Splay \text{Splay} Splay 的基本操作。
数组定义
- c h [ x ] [ 0 / 1 ] ch[x][0/1] ch[x][0/1],表示 x x x的左儿子或者右儿子。
- v a l [ x ] val[x] val[x],表示 x x x点的键值。
- c n t [ x ] cnt[x] cnt[x],表示 x x x该点的出现次数。
- p a r [ x ] par[x] par[x],表示 x x x的父亲。
- s i z [ x ] siz[x] siz[x],表示 x x x为根的子树的大小。
具体操作
chk 操作
辅助操作,找
x
x
x是它父亲的左儿子还是右儿子。
int chk(int x)
{
return ch[par[x]][1]==x;
}
push_up 操作
辅助操作,用左儿子和右儿子更新一下
s
i
z
siz
siz数组。
void push_up(int x)
{
siz[x]=siz[ch[x][0]]+siz[ch[x][1]]+cnt[x];
}
rotate 操作
Splay
\text{Splay}
Splay的核心操作,旋转
x
x
x,先看一个例子:
其中,一种较为优秀的转法是这样的:
多模拟几次,我们总结一下,假设它的父亲是
y
y
y,
y
y
y的父亲是
z
z
z,我们先找出
x
x
x是
y
y
y的 左儿子
/
/
/右儿子,记它为
k
k
k,我们把
c
h
[
y
]
[
k
]
ch[y][k]
ch[y][k]替换成
c
h
[
x
]
[
!
k
]
ch[x][!k]
ch[x][!k],把
c
h
[
z
]
[
c
h
k
(
y
)
]
ch[z][chk(y)]
ch[z][chk(y)]替换成
x
x
x,
c
h
[
x
]
[
!
k
]
ch[x][!k]
ch[x][!k]替换成
y
y
y,然后再更新
x
x
x和
y
y
y,代码如下:
void rotate(int x)
{
int y=par[x],z=par[y],k=chk(x),w=ch[x][k^1];
ch[y][k]=w;par[w]=y;
ch[z][chk(y)]=x;par[x]=z;
ch[x][k^1]=y;par[y]=x;
push_up(y);push_up(x);
}
Splay 操作
核心操作,把点
x
x
x旋到
y
y
y的子节点处,这里我们使用双选,如果
x
,
y
,
z
x,y,z
x,y,z三点共线,那么我们先旋转
y
y
y,再旋转
x
x
x;否则我们旋转两次
x
x
x,这样旋转出来的树形态更优,代码如下:
void splay(int x,int goal=0)
{
while(par[x]!=goal)
{
int y=par[x],z=par[y];
if(z!=goal)
{
if(chk(x)==chk(y)) rotate(y);
else rotate(x);
}
rotate(x);
}
if(!goal) rt=x;
}
find 操作
把第一个值小于等于
x
x
x点旋转到根,我们先用二叉查找树的方法找到它,然后直接
Splay
\text{Splay}
Splay 它到根。
void find(int x)
{
if(!rt) return ;
int cur=rt;
while(ch[cur][x>val[cur]] && x!=val[cur])
cur=ch[cur][x>val[cur]];
splay(cur);
}
insert 操作
插入
x
x
x这个值,我们先查找这个值,如果找到了,把次数
+
1
+1
+1,否则我们新建一个节点,然后把这个节点旋转到根(随机化树形态)。
void insert(int x)
{
int cur=rt,p=0;
while(cur && val[cur]!=x)
{
p=cur;
cur=ch[cur][x>val[cur]];
}
if(cur) cnt[cur]++;
else
{
cur=++ncnt;
if(p) ch[p][x>val[p]]=cur;
par[cur]=p;ch[cur][0]=ch[cur][1]=0;
cnt[cur]=siz[cur]=1;val[cur]=x;
}
splay(cur);
}
kth 操作
找到第
k
k
k大的值,用普通平衡树的方法,先找左子树的点数够不够,否则看左子树+当前点数够不够,足够则第
k
k
k大就是当前点,否则去找右子树,具体实现如下:
int kth(int k)
{
int cur=rt;
while(1)
{
if(ch[cur][0] && k<=siz[ch[cur][0]])
cur=ch[cur][0];
else if(k>siz[ch[cur][0]]+cnt[cur])
{
k-=siz[ch[cur][0]]+cnt[cur];
cur=ch[cur][1];
}
else return cur;
}
}
pre
/
/
/suc 操作
用
f
i
n
d
(
x
)
find(x)
find(x)把小于等于
x
x
x的值旋转到根,如果根是
x
x
x,那么找 左子树最底下的右儿子
/
/
/右子树最底下的左儿子,否则答案就是根。
int pre(int x)
{
find(x);
if(val[rt]<x) return rt;
int cur=ch[rt][0];
while(ch[cur][1]) cur=ch[cur][1];
return cur;
}
int suc(int x)
{
find(x);
if(val[rt]>x) return rt;
int cur=ch[rt][1];
while(ch[cur][0]) cur=ch[cur][0];
return cur;
}
remove 操作
删除
x
x
x,找到
x
x
x的前驱和后继,把前驱旋转到根,后继旋转到前驱,所以
x
x
x一定是后继的左儿子,且
x
x
x的子树为空,所以可以直接删除
x
x
x,具体实现如下:
void remove(int x)
{
int lst=pre(x),nxt=suc(x);
splay(lst);splay(nxt,lst);
int now=ch[nxt][0];
if(cnt[now]>1)
{
cnt[now]--;
splay(now);
}
else ch[nxt][0]=0;
}
至此 Splay \text{Splay} Splay 的基本操作就讲完了,下面贴个完整代码吧,更多操作还是慢慢学吧。
#include <cstdio>
const int M = 200005;
int read()
{
int x=0,flag=1;char c;
while((c=getchar())<'0' || c>'9') if(c=='-') flag=-1;
while(c>='0' && c<='9') x=(x<<3)+(x<<1)+(c^48),c=getchar();
return x*flag;
}
int n,rt,ncnt,ch[M][2],val[M],cnt[M],par[M],siz[M];
int chk(int x)
{
return ch[par[x]][1]==x;
}
void push_up(int x)
{
siz[x]=siz[ch[x][0]]+siz[ch[x][1]]+cnt[x];
}
void rotate(int x)
{
int y=par[x],z=par[y],k=chk(x),w=ch[x][k^1];
ch[y][k]=w;par[w]=y;
ch[z][chk(y)]=x;par[x]=z;
ch[x][k^1]=y;par[y]=x;
push_up(y);push_up(x);
}
void splay(int x,int goal=0)
{
while(par[x]!=goal)
{
int y=par[x],z=par[y];
if(z!=goal)
{
if(chk(x)==chk(y)) rotate(y);
else rotate(x);
}
rotate(x);
}
if(!goal) rt=x;
}
void find(int x)
{
if(!rt) return ;
int cur=rt;
while(ch[cur][x>val[cur]] && x!=val[cur])
cur=ch[cur][x>val[cur]];
splay(cur);
}
void insert(int x)
{
int cur=rt,p=0;
while(cur && val[cur]!=x)
{
p=cur;
cur=ch[cur][x>val[cur]];
}
if(cur) cnt[cur]++;
else
{
cur=++ncnt;
if(p) ch[p][x>val[p]]=cur;
par[cur]=p;ch[cur][0]=ch[cur][1]=0;
cnt[cur]=siz[cur]=1;val[cur]=x;
}
splay(cur);
}
int kth(int k)
{
int cur=rt;
while(1)
{
if(ch[cur][0] && k<=siz[ch[cur][0]])
cur=ch[cur][0];
else if(k>siz[ch[cur][0]]+cnt[cur])
{
k-=siz[ch[cur][0]]+cnt[cur];
cur=ch[cur][1];
}
else return cur;
}
}
int pre(int x)
{
find(x);
if(val[rt]<x) return rt;
int cur=ch[rt][0];
while(ch[cur][1]) cur=ch[cur][1];
return cur;
}
int suc(int x)
{
find(x);
if(val[rt]>x) return rt;
int cur=ch[rt][1];
while(ch[cur][0]) cur=ch[cur][0];
return cur;
}
void remove(int x)
{
int lst=pre(x),nxt=suc(x);
splay(lst);splay(nxt,lst);
int now=ch[nxt][0];
if(cnt[now]>1)
{
cnt[now]--;
splay(now);
}
else ch[nxt][0]=0;
}
int main()
{
n=read();
insert(0x3f3f3f3f);
insert(0xcfcfcfcf);
for(int i=1;i<=n;i++)
{
int op=read(),x=read();
if(op==1) insert(x);
if(op==2) remove(x);
if(op==3)
{
find(x);
printf("%d\n",siz[ch[rt][0]]);
}
if(op==4) printf("%d\n",val[kth(x+1)]);
if(op==5) printf("%d\n",val[pre(x)]);
if(op==6) printf("%d\n",val[suc(x)]);
}
}
例题
第一道题:序列终结者,
splay
\text{splay}
splay打标记入门题
第二道题:SuperMemo,这道题相对于上一道题多了一个
Revolve
\text{Revolve}
Revolve操作,直接把
[
l
,
r
−
k
]
[l,r-k]
[l,r−k]这个区间拆出来,把它当作点,直接插回原序列,下面贴上我的代码:
#include <cstdio>
#include <iostream>
using namespace std;
#define inf 0x3f3f3f3f
const int M = 200005;
int read()
{
int x=0,flag=1;char c;
while((c=getchar())<'0' || c>'9') if(c=='-') flag=-1;
while(c>='0' && c<='9') x=(x<<3)+(x<<1)+(c^48),c=getchar();
return x*flag;
}
int n,m,rt,ncnt,ch[M][2],val[M],Min[M],par[M],siz[M],fl[M],la[M];
char s[10];
int chk(int x)
{
return ch[par[x]][1]==x;
}
void push_up(int x)//上传
{
if(!x) return ;
siz[x]=siz[ch[x][0]]+siz[ch[x][1]]+1;
Min[x]=min(min(Min[ch[x][0]],Min[ch[x][1]]),val[x]);
}
void flip(int x)//翻转
{
if(!x) return ;
swap(ch[x][0],ch[x][1]);
fl[x]^=1;
}
void add(int x,int c)//加法
{
if(!x) return ;
Min[x]+=c;val[x]+=c;
la[x]+=c;
}
void push_down(int x)//下传标记
{
if(fl[x])
{
flip(ch[x][0]);flip(ch[x][1]);
fl[x]=0;
}
if(la[x])
{
add(ch[x][0],la[x]);add(ch[x][1],la[x]);
la[x]=0;
}
}
void rotate(int x)//旋转
{
int y=par[x],z=par[y],k=chk(x),w=ch[x][k^1];
push_down(y);push_down(x);
ch[y][k]=w;par[w]=y;
ch[z][chk(y)]=x;par[x]=z;
ch[x][k^1]=y;par[y]=x;
push_up(y);push_up(x);
}
void splay(int x,int goal=0)//把x旋转到goal
{
while(par[x]^goal)
{
int y=par[x],z=par[y];
if(z!=goal)
{
if(chk(x)==chk(y)) rotate(y);
else rotate(x);
}
rotate(x);
}
if(!goal) rt=x;
}
int find(int k)//排名为k的点
{
int cur=rt;
while(1)
{
push_down(cur);
if(ch[cur][0] && k<=siz[ch[cur][0]])
cur=ch[cur][0];
else if(k>siz[ch[cur][0]]+1)
{
k-=siz[ch[cur][0]]+1;
cur=ch[cur][1];
}
else return cur;
}
}
void print(int x)
{
if(!x) return ;
push_down(x);
print(ch[x][0]);
printf("%d ",val[x]);
print(ch[x][1]);
}
void ins(int x,int k)//把x插入k位后
{
int a=find(k),b=find(k+1);
splay(a);splay(b,a);
ch[b][0]=x;par[x]=b;
push_up(b);
}
void del(int k)//删除k位
{
int a=find(k-1),b=find(k+1);
splay(a);splay(b,a);
ch[b][0]=0;
push_up(b);
}
int main()
{
n=read();
Min[0]=inf;ncnt=2;
rt=1;siz[1]=siz[2]=1;
ch[1][1]=2;par[2]=1;//加入哨兵
for(int i=1;i<=n;i++)
{
siz[++ncnt]=1;val[ncnt]=Min[ncnt]=read();
ins(ncnt,ncnt-2);
}
m=read();
while(m--)
{
scanf("%s",s);
if(s[0]=='D')
{
del(read()+1);//删除(要考虑哨兵)
continue ;
}
int l=read(),r=read();
if(s[0]=='A')//区间加
{
int a=find(l),b=find(r+2);
splay(a);splay(b,a);
add(ch[b][0],read());
}
if(s[0]=='R' && s[3]=='E')//翻转
{
int a=find(l),b=find(r+2);
splay(a);splay(b,a);
flip(ch[b][0]);
}
if(s[0]=='R' && s[3]=='O')
{
int k=read();
k=(k%(r-l+1)+(r-l+1))%(r-l+1);
//[l,r-k]
int a=find(l),b=find(r-k+2);//拆区间
splay(a);splay(b,a);
int t=ch[b][0];
ch[b][0]=0;par[t]=0;
ins(t,l+k);//重新插入
}
if(s[0]=='I')//插入
{
siz[++ncnt]=1;
val[ncnt]=Min[ncnt]=r;
ins(ncnt,l+1);
}
if(s[0]=='M')//查询最小值
{
int a=find(l),b=find(r+2);
splay(a);splay(b,a);
printf("%d\n",Min[ch[b][0]]);
}
}
}