引入
先看例题:(洛谷 P3369 【模板】普通平衡树)
您需要写一种数据结构,来维护一些数,其中需要提供以下操作:
1.插入 x x x 数
2.删除 x x x 数(若有多个相同的数,因只删除一个)
3.查询 x x x 数的排名(排名定义为比当前数小的数的个数 + 1 +1 +1 )
4.查询排名为 x x x 的数
5.求 x x x 的前驱(前驱定义为小于 x x x,且最大的数)
6.求 x x x 的后继(后继定义为大于 x x x,且最小的数)
显然可以使用BST(二叉搜索树)完成,它的时间复杂度是 O ( l o g n ) O(logn) O(logn),然而在构造数据下,BST会成为一条链,时间复杂度会退化到 O ( n ) O(n) O(n)。
这时,我们就要用到一种“升级版”的BST——平衡树了。
平衡树有很多种,如Splay,Treap,替罪羊树,红黑树,fhq Treap等,这里介绍的是Splay。
Splay
先来说一下BST的基本结构特点:左儿子 < < < 根 < < < 右儿子,中序遍历为有序序列。
这样,我们在操作时,只需要跳左儿子或右儿子就能够找到待操作的节点,在平凡情况下,时间复杂度是 O ( l o g n ) O(logn) O(logn)的。
然而,对于下图(也就是链),它并没有产生应有的优化效果:
对于这样一个BST,将会一直跳左儿子,时间复杂度将会退化至
O
(
n
)
O(n)
O(n)。
而平衡树,通过其独有的旋转操作,使得BST的左右子树尽量平衡。所谓的旋转,就是在新加入一个点时,将节点重新排列,使得新节点成为根节点,并且仍然满足BST的性质。
节点
对于每个节点,我们需要存储以下信息:
struct Splay
{
int fa;//父节点
int ch[2];//子节点
int val;//权值
int cnt;//权值的出现次数
int size;//所在子树的大小
}
t[MAXN];
基操
void maintain(int x)
{
size[x]=size[ch[x][0]]+size[ch[x][1]]+cnt[x];
}//维护子树大小:子树大小为左子树大小、右子树大小与权值数量之和
bool get(int x)
{
return x==ch[fa[x]][1];
}//判断节点是其父亲的左儿子还是右儿子
void clear(int x)
{
ch[x][0]=ch[x][1]=fa[x]=val[x]=cnt[x]=size[x]=0;
}//清除节点的所有信息
旋转
先来了解一下旋转的过程,旋转分为右旋和左旋:
可以发现,旋转并不是简单的图的旋转,而是要改变一些节点之间的关系,使得旋转过后仍然满足BST的性质,而对于上述的链,就可以通过旋转变化为相对平衡的树。
下面以上图为例,分析一下右旋的步骤:
要把 2 2 2 旋到它父节点的位置,为了满足BST性质, 4 4 4 必须要成为 2 2 2 的右儿子。 3 3 3 本来就是 2 2 2 的右儿子,所以在旋转后 3 3 3 必定在右子树中,即在 4 4 4 所在的子树中,而 3 3 3 本来在 4 4 4 的左子树中,所以旋转后必定是 4 4 4 的左儿子。
具体的,我们要将 2 2 2 右旋,我们要先将 4 4 4 的左儿子设为 3 3 3,并将 3 3 3 的父亲设为 4 4 4,将 2 2 2 的右儿子设为 4 4 4,并将 4 4 4 的父亲设为 2 2 2。
那么,右旋的一般步骤是什么呢?
设要旋转的节点为 x x x,它的父亲为 y y y, y y y 的父亲为 z z z。
- 将 y y y 的左儿子设为 x x x 的右儿子
- 若 x x x 的右儿子存在,将 x x x 的右儿子的父亲设为 y y y
- 将 x x x 的右儿子设为 y y y
- 将 y y y 的父亲设为 x x x
- 将 x x x 的父亲设为 z z z
- 若 z z z 存在,将 z z z 的某个子节点(原来 y y y 所在的子节点)设为 x x x
Update 2022.10.11:表达能力不太行,重新组织了一下语言,比如将奇怪的“指向”二字改为了“设为”。
对于一个需要旋转的节点,若它是父节点的左儿子则需要右旋,若它是父节点的右儿子则需要左旋,而左旋的步骤与右旋正好相反,所以可以将右旋和左旋放在一个函数里:
void rotate(int x)
{
int y=t[x].fa,z=t[y].fa,chk=get(x);
t[y].ch[chk]=t[x].ch[chk^1];//1
if(t[x].ch[chk^1])
t[t[x].ch[chk^1]].fa=y;//2
t[x].ch[chk^1]=y;//3
t[y].fa=x;//4
t[x].fa=z;//5
if(z)
t[z].ch[y==t[z].ch[1]]=x;//6
maintain(y);
maintain(x);
}
在Splay中,每加入一个新的节点就需要把它旋转到根。
设当前需旋转的节点为 x x x,节点的旋转可分为以下三种:
- x x x 的父亲是根,这时直接旋转即可
- 父亲和 x x x 的儿子类型相同(即同为左儿子或同为右儿子),这时先旋转父亲,再旋转 x x x
- 父亲和 x x x 的儿子类型不同,这时将 x x x 旋转两次
void splay(int x)
{
for(int f=t[x].fa;f=t[x].fa,f;rotate(x))
if(t[f].fa)
rotate(get(x)==get(f)?f:x);
root=x;
}
splay函数的作用是将节点 x x x 旋转到根,以维护BST的随机性。事实上,单次的splay函数不一定使得BST结构变得完全平衡,甚至有可能使树结构更劣,但是由于在接下来的操作中频繁进行旋转,使得树结构不确定,不会被刻意构造的数据卡掉,均摊复杂度达到 O ( log n ) O(\log n) O(logn)。
插入
旋转是平衡树的核心操作。平衡树本身就是BST,所以它的操作也与一般的BST大同小异,不过需要注意进行旋转。
向平衡树中加入一个值 k k k,要按照BST的性质寻找权值为 k k k 的节点。对节点的操作,可分为以下三种:
- 若当前节点权值为 k k k,将权值数加 1 1 1,维护子树大小,旋转
- 若当前节点权值大于 k k k,则跳到左儿子;若当前节点权值小于 k k k,则跳到右儿子
- 若当前节点不存在,则建立新节点,维护节点信息,维护子树大小,旋转
Update 2022.10.11:这里及下文的旋转是指将当前节点旋转到根。
这样,我们就插入了权值 k k k,且把权值为 k k k 的节点旋转到了根。
void insert(int k)
{
if(!root)//若树为空
{
t[++tot].val=k;
t[tot].cnt++;
root=tot;
maintain(root);
return;
}
int cur=root,f=0;
while(1)
{
if(t[cur].val==k)//1
{
t[cur].cnt++;
maintain(cur);
maintain(f);
splay(cur);
break;
}
f=cur;
cur=t[f].ch[t[f].val<k];//2
if(!cur)//3
{
t[++tot].val=k;
t[tot].cnt++;
t[tot].fa=f;
t[f].ch[t[f].val<k]=tot;
maintain(tot);
maintain(f);
splay(tot);
break;
}
}
}
插入操作是比较复杂的操作,对照代码好好理解,麻烦的是要将节点信息全部更新。
查询排名
给出一个值 x x x,求出它的排名,排名定义为比当前数小的数的个数 + 1 +1 +1。
我们可以不断沿着树边向下寻找,可分为三种情况:
- 当前节点值大于 x x x,向左子树走
- 当前节点值为 x x x,累加左子树 s i z e size size,旋转,返回答案
- 当前节点值小于 x x x,累加左子树 s i z e size size,累加当前节点 c n t cnt cnt,向右子树走
可能不大好理解,具体见注释:
int rnk(int x)
{
int res=0,cur=root;
while(1)
{
if(x<t[cur].val)//向左子树走,而不用累加答案,因为比x小的都在左子树
cur=t[cur].ch[0];
else
{
res+=t[t[cur].ch[0]].size;//累加左子树的size,因为左子树上的权值都小于x
if(x==t[cur].val)//如果权值与x相等
{
splay(cur);//旋转
return res+1;//“排名定义为比当前数小的数的个数+1”
}
res+=t[cur].cnt;//累加当前节点size,因为当前节点权值小于x
cur=t[cur].ch[1];//右子树
}
}
}
可能有人会说了,为什么找到后要旋转呢?这里的旋转有什么用呢?
旋转是平衡树用来维护BST的相对平衡的操作,所以说,每一个旋转都是在尽量的减少运行时间,旋转操作多多益善。
而旋转除了维护平衡,它还有什么作用呢?可以将节点旋转到根。
在下面的部分我们会看到,会利用到这个功能。所以还是尽量去多写旋转,多了没问题,少了可能会挂。
查询数值
给定一个数 k k k,查询排名为 k k k 的数。
分为两种情况:
- 若 k k k 小于等于左子树的 s i z e size size,则说明排名为 k k k 的值在左子树中,向左子树走
- 否则,将 k k k 减去左子树的 s i z e size size 和当前节点的 c n t cnt cnt,使得 k k k 等于在右子树中的排名。然而若 k k k 小于等于 0 0 0,说明已经找到,进行旋转,返回当前节点权值。
int kth(int k)
{
int cur=root;
while(1)
{
if(t[cur].ch[0]&&k<=t[t[cur].ch[0]].size)//左子树存在且排名为k的值在左子树
cur=t[cur].ch[0];
else
{
k-=t[t[cur].ch[0]].size+t[cur].cnt;//将k改为在右子树的排名
if(k<=0)//如果排名小于等于0,说明已经找到,直接返回
{
splay(cur);
return t[cur].val;
}
cur=t[cur].ch[1];
}
}
}
前驱和后继
查找前驱和后继的方法极为相似,所以我以前驱为例讲解。
首先,为了便于查找 x x x 的前驱,我们插入一个值为 x x x 的节点,查找完之后再删掉(删除会最后讲)。
这时,就要用到旋转的作用了!我们在插入一个节点的同时,会将它旋转到根。所以,我们只需要查找根的前驱。
还记得前驱的定义是什么吗?
前驱定义为小于 x x x,且最大的数
首先,我们找小于 x x x 的值,显然在左子树。然后,我们找最大的值,显然,要一直向右子树走,走到叶子结点,叶子结点的值就是答案。
这样,我们就得到了求出前驱的步骤:先向左走一下,再一直向右走。
而后继就正好与前驱相反:先向右走一下,再一直向左走。
int pre()
{
int cur=t[root].ch[0];//向左
if(!cur)//如果已经到叶子结点
return cur;
while(t[cur].ch[1])//向右
cur=t[cur].ch[1];
splay(cur);//旋转
return cur;
}
int nxt()
{
int cur=t[root].ch[1];//向右
if(!cur)
return cur;
while(t[cur].ch[0])//向左
cur=t[cur].ch[0];
splay(cur);//旋转
return cur;
}
调用:
insert(x);
printf("%d\n",t[pre()].val);
del(x);//前驱
insert(x);
printf("%d\n",t[nxt()].val);
del(x);//后继
删除
删除算是比较难的操作了。对于删除操作,我们要先将待删除的点旋转到根,而给出的是待删除的值 x x x。所以我们只需调用一次 r a n k rank rank 函数。
旋转到根以后,我们就只需删除根节点即可。分为以下五种情况:
- 根节点的 c n t cnt cnt 大于 1 1 1,将 c n t cnt cnt 减 1 1 1 即可
- 根节点没有左儿子和右儿子,直接 c l e a r clear clear 掉,根指向 0 0 0
- 根节点没有左儿子,只有右儿子,将根设为右儿子,新根的父亲设为 0 0 0, c l e a r clear clear 掉旧根
- 根节点没有右儿子,只有左儿子,将根设为左儿子,新根的父亲设为 0 0 0, c l e a r clear clear 掉旧根
- 根节点有左右儿子,这时我们找到 x x x 的前驱来做根节点。我们通过 p r e pre pre 函数将 x x x 的前驱旋转到根,然后将旧根的右儿子的父亲设为新根,将新根的右儿子设为旧根的右儿子,然后 c l e a r clear clear 掉旧根,维护新根的节点信息。
void del(int k)
{
rnk(k);
if(t[root].cnt>1)//1
{
t[root].cnt--;
maintain(root);
return;
}
if(!t[root].ch[0]&&!t[root].ch[1])//2
{
clear(root);
root=0;
return;
}
if(!t[root].ch[0])//3
{
int cur=root;
root=t[root].ch[1];
t[root].fa=0;
clear(cur);
return;
}
if(!t[root].ch[1])//4
{
int cur=root;
root=t[root].ch[0];
t[root].fa=0;
clear(cur);
return;
}
int cur=root;//5
int x=pre();
t[t[cur].ch[1]].fa=root;
t[root].ch[1]=t[cur].ch[1];
clear(cur);
maintain(root);
}
总结
那么,Splay的知识就讲完了,在写平衡树的时候要注意以下几点:
- 尽量多旋转,就算不额外旋转也一定不要漏掉必要的旋转
- 求前驱和后缀一定别忘了删除新插入的点
- 删除别忘了先 r n k rnk rnk,有左右儿子别忘了 p r e pre pre
- 在所有操作中都别忘了维护节点信息
- 写挂了建议重写
代码
#include<iostream>
#include<cstdio>
#define MAXN 100010
using namespace std;
int root,tot;
struct Splay
{
int fa;
int ch[2];
int val;
int cnt;
int size;
}
t[MAXN];
void maintain(int x)
{
t[x].size=t[t[x].ch[0]].size+t[t[x].ch[1]].size+t[x].cnt;
}
bool get(int x)
{
return x==t[t[x].fa].ch[1];
}
void clear(int x)
{
t[x].ch[0]=t[x].ch[1]=t[x].fa=t[x].val=t[x].cnt=t[x].size=0;
}
void rotate(int x)
{
int y=t[x].fa,z=t[y].fa,chk=get(x);
t[y].ch[chk]=t[x].ch[chk^1];
if(t[x].ch[chk^1])
t[t[x].ch[chk^1]].fa=y;
t[x].ch[chk^1]=y;
t[y].fa=x;
t[x].fa=z;
if(z)
t[z].ch[y==t[z].ch[1]]=x;
maintain(y);
maintain(x);
}
void splay(int x)
{
for(int f=t[x].fa;f=t[x].fa,f;rotate(x))
if(t[f].fa)
rotate(get(x)==get(f)?f:x);
root=x;
}
void insert(int k)
{
if(!root)
{
t[++tot].val=k;
t[tot].cnt++;
root=tot;
maintain(root);
return;
}
int cur=root,f=0;
while(1)
{
if(t[cur].val==k)
{
t[cur].cnt++;
maintain(cur);
maintain(f);
splay(cur);
break;
}
f=cur;
cur=t[f].ch[t[f].val<k];
if(!cur)
{
t[++tot].val=k;
t[tot].cnt++;
t[tot].fa=f;
t[f].ch[t[f].val<k]=tot;
maintain(tot);
maintain(f);
splay(tot);
break;
}
}
}
int rnk(int k)
{
int res=0,cur=root;
while(1)
{
if(k<t[cur].val)
cur=t[cur].ch[0];
else
{
res+=t[t[cur].ch[0]].size;
if(k==t[cur].val)
{
splay(cur);
return res+1;
}
res+=t[cur].cnt;
cur=t[cur].ch[1];
}
}
}
int kth(int k)
{
int cur=root;
while(1)
{
if(t[cur].ch[0]&&k<=t[t[cur].ch[0]].size)
cur=t[cur].ch[0];
else
{
k-=t[t[cur].ch[0]].size+t[cur].cnt;
if(k<=0)
{
splay(cur);
return t[cur].val;
}
cur=t[cur].ch[1];
}
}
}
int pre()
{
int cur=t[root].ch[0];
if(!cur)
return cur;
while(t[cur].ch[1])
cur=t[cur].ch[1];
splay(cur);
return cur;
}
int nxt()
{
int cur=t[root].ch[1];
if(!cur)
return cur;
while(t[cur].ch[0])
cur=t[cur].ch[0];
splay(cur);
return cur;
}
void del(int k)
{
rnk(k);
if(t[root].cnt>1)
{
t[root].cnt--;
maintain(root);
return;
}
if(!t[root].ch[0]&&!t[root].ch[1])
{
clear(root);
root=0;
return;
}
if(!t[root].ch[0])
{
int cur=root;
root=t[root].ch[1];
t[root].fa=0;
clear(cur);
return;
}
if(!t[root].ch[1])
{
int cur=root;
root=t[root].ch[0];
t[root].fa=0;
clear(cur);
return;
}
int cur=root;
int x=pre();
t[t[cur].ch[1]].fa=root;
t[root].ch[1]=t[cur].ch[1];
clear(cur);
maintain(root);
}
int n,op,x;
int main()
{
scanf("%d",&n);
while(n--)
{
scanf("%d%d",&op,&x);
if(op==1)
insert(x);
else if(op==2)
del(x);
else if(op==3)
printf("%d\n",rnk(x));
else if(op==4)
printf("%d\n",kth(x));
else if(op==5)
{
insert(x);
printf("%d\n",t[pre()].val);
del(x);
}
else
{
insert(x);
printf("%d\n",t[nxt()].val);
del(x);
}
}
return 0;
}
撒花!!!