Splay树
首先介绍BST,也就是所有平衡树的开始,他的China名字是二叉查找树。
BST性质简介
给定一棵二叉树,每一个节点有一个权值,命名“关键码”,至于为什么叫这个名字,我也不知道。BST的性质就是,对于树中任何一个节点,都满足以下性质:
1.这个节点的关键码不小于它的左子树上任意一个节点的关键码
2.这个节点的关键码不大于它右子树上的任意一个关键码
然后我们就可以发现这棵树的中序遍历,就是一个关键码单调递增的节点序列。
Splay的背景
什么是Splay,它就是一种可以旋转的平衡树。它可以解决BST树这棵树一个极端情况,也就是退化情况。如下图所示。
我们发现上面这个图是一条链,这种恐怖的数据让时间从O(log(n))退化到O(n)。
Splay思路
这是一棵特殊的BST树,或者说平衡树基本都是改变树结构样式,但是却不改变最后得出的排序序列。
这张图片大致意思如下:正方形部分表示一棵子树,然后圆形表示节点(当然了其实你也可以都看成节点,可能更好理解)
对于这样一棵树,我们可以做一些特殊的操作,来让它变换树的形态结构,但是最后的答案是正确的。平衡树的精髓就是这个,就是改变树的形态结构,但是不改变中序遍历,也就是答案数组。
重点来了,以下为splay精华所在之处,一定要全神贯注地看,并且手中拿着笔和草稿纸,一步步跟着我一起做,那么我保证你看完一遍就可以懂。
现在我们的目标只有一个:x节点往上爬,爬到它原本的父节点y,然后让y节点下降。
首先思考BST的性质,那就是右子树上的点,统统都比父节点大对不对,现在我们的x节点是父节点y的左节点,也就是比y节点小。那么我们为了不改变答案的顺序,我们可以让y节点成为x节点的右儿子,也就是y节点仍然大于我们的x节点。
那么这样做当然是没有问题的,那么我们现在又有一个问题,x节点的右子树原来是有子树B的,那么如果说现在y节点以及它的右子树(没有左子树,因为曾经x节点是它的左子树),放到了x节点的右子树上,那么岂不是多了些什么吗?
我们知道x节点的右子树必然是大于x节点的,然后y节点必然是大于x节点的右子树和x节点的,因为x节点和它的右子树都是在y节点的左子树,都比它小。
既然如此的话,我们为什么不把x节点原来的右子树放在y的左子树上面呢?这样的话,我们就巧妙地避开了冲突,达成了x节点上移,y节点下移。
移动后的图片
这就是一种情况,但是我们不能局限于一种情况,我们要找到通解。以下为通解
若节点x为y节点的位置z(z为0,则为左节点,1则为右节点)
1.那么y节点就放到x节点的z^1的位置(也就是,x节点为y节点的右子树,那么y节点就放到左子树,x节点为y节点的左子树,那么y节点就放到右子树)
2.如果x节点的z ^ 1的位置上已经有节点或者一棵子树,那么我们就将原来x节点z ^ 1位置上的子树,放到y节点的位置z上面。
3.移动完毕
yyb大佬的代码,个人认为最精简的代码&最适合理解的代码
t是树上节点的结构体,ch数组表示左右儿子,ch[0]是左儿子,ch[1]是右儿子,ff是父节点
struct splay_tree //定义splay_tree
{
int ff,cnt,ch[2],val,size;
}
void update(int x)
{
t[x].size=t[t[x].ch[0]].size+t[t[x].ch[1]].size+t[x].cnt;//左子树+右子树+本身多少个
}
void rotate(int x)//x是要旋转的节点
{
int y=t[x].ff; //X的父亲
int z=t[y].ff; //X的祖父
int k=t[y].ch[1]==x;//X是Y的哪一个儿子 0是左儿子 1是右儿子
t[z].ch[t[z].ch[1]==y]=x;//Z的原来的Y位置变为X
t[x].ff=z;//x的父亲变为Z
t[y].ch[k]=t[x].ch[k^1];//X的与X原来在Y的相对的那个儿子变成Y的儿子
t[t[x].ch[k^1]].ff=y; //更新父节点
t[x].ch[k^1]=y;
t[y].ff=x;//更新父节点
update(y);
update(x);
}
如果你已经读到了这里,那么恭喜你,现在的你成功完成了splay的大部分了,但是你发现这条链表结构还是会卡死你,不要气恼,因为你只需要再来一个splay函数就好了,你已经完成了85%,接下来很容易。
如果说x,y,z这三个节点共线,也就是x和它的父亲节点和它的祖先节点在同一条线段上的话,那么我们就需要来一些特殊处理了,其实就是很容易的一些操作:
下面就是三点共线的一张图片
这张图片里面的最长链是Z->Y->X->A
如果我们一直都是x旋转的话,那么就会得到下面这张图片。
而一直旋转x的最长链是X->Z->Y->B
我们发现旋转和不旋转没有任何区别,算法失败了,不过不用害怕,其实我们还有办法。
1.如果当前处于共线状态的话,那么先旋转y再旋转x,这样可以强行让他们不共线,然后平衡这棵树
2.如果当前不是共线状态的话,那么只要旋转x即可
当你看懂这个以后恭喜你,你已经成功学会splay的双旋操作了。
splay操作
这里将0可以理解为根节点的父亲,主要是为了防止各种操作时越界。
void splay(int x,int goal) //将x旋转为goal的儿子,如果goal是0则旋转到根
{
while(t[x].ff!=goal)
{
int y=t[x].ff,z=t[y].ff;//父节点和祖父节点
if(z!=goal) //如果Y不是根节点,则分上面两类来旋转
(t[z].ch[0]==y)^(t[y].ch[0]==x)?rotate(x):rotate(y);//判断共线还是不共线
rotate(x);//无论怎么样最后的一个操作都是旋转x
}
if(goal==0) root=x;//如果goal是0,则将根节点更新为x
}
查找find操作
从根节点开始,左侧都比他小,右侧都比他大,所以只需要相应的往左/右递归如果当前位置的val已经是要查找的数,那么直接把他splay到根节点,方便接下来的操作。
void find(int x) //查找x的位置
{
int u=root;
if(!u) return ;//树空
while(t[u].ch[x>t[u].val]&&x!=t[u].val) //当存在儿子并且当前位置的值不等于x
u=t[u].ch[x>t[u].val];//跳转到儿子
splay(u,0);//把当前位置旋转到根节点
}
Insert操作
往Splay中插入一个数
类似于find操作,只是如果是已经存在的数,就可以直接在查找到的节点进行计数,如果不存在,在递归的过程中,会找到他的父节点的位置,然后在底下新建节点即可。
void insert(int x)//插入x
{
int u=root,ff=0; //当前位置u,u的父亲节点ff
while(u&&t[u].val!=x) //当u存在且没有移动到正确的位置
{
ff=u;
u=t[u].ch[x>t[u].val]; //大于当前位置则向右找,否则向左找
}
if(u) t[u].cnt++; //增加一个数
else//不存在这个数字,要新建一个节点来存在
{
u=++tot;//新节点的位置
if(ff) //如果父节点非根
t[ff].ch[x>t[ff].val]=u;
t[u].ch[0]=t[u].ch[1]=0;//不存在儿子
t[tot].ff=ff;//父节点
t[tot].val=x;//值
t[tot].cnt=1;
t[tot].size=1;//大小
}
splay(u,0);
}
前驱/后继操作Next
首先就要执行find操作,把要查找的数弄到根节点上。
然后以前驱为例,先确定前驱比他小,所以在左子树上,然后他的前驱是左子树上的最大值。
int Next(int x,int f)//查找x的前驱(0)或者后继(1)
{
find(x);
int u=root;
if(t[u].val>x&&f) return u;//如果当前节点的值大于x并且要查找的是后继
if(t[u].val<x&&!f) return u;//如果当前节点的值小于x并且要查找的是前驱
u=t[u].ch[f]; //查找后继的话在右儿子上找,前驱在左儿子上找
while(t[u].ch[f^1]) u=t[u].ch[f^1];
return u; //返回位置
}
删除操作
现在就很简单啦,首先找到这个数的前驱,把他Splay到根节点,然后找到这个数的后继,把他旋转到前驱的底下。比前驱大的数是后继,在右子树比后继小的且比前驱大的数有且仅有当前数。在后继的左子树上面。因此把当前根节点的右儿子的左儿子删掉就可以了。
void Delete(int x)//删除x
{
int last=Next(x,0);//找x的前驱
int next=Next(x,1);//找x的后继
splay(last,0); //把前驱旋转到根节点
splay(next,last);//后继旋转到根节点下面
int del=t[next].ch[0];//后继的左儿子
if(t[del].cnt>1) //如果超过一个
{
t[del].cnt--;//直接减少一个
splay(del,0);//旋转
}
else t[next].ch[0]=0;//这个节点直接丢掉
}
第k大的数
从当前根节点开始,检查左子树的大小。因为所有比当前位置小的数都在左侧。如果左侧的数的个数多余k,则证明第K大在左子树中。否则向右子树寻找。
int kth(int x) //查找排名x的数
{
int u=root;//当前根节点
if(t[u].size<x) //如果当前树上没有那么多数
return 0;//不存在
while(1)
{
int y=t[u].ch[0];//左儿子
if(x>t[y].size()+t[u].cnt)
//如果排名比左儿子的大小和当前节点的数量要大
{
x-=t[y].size+t[u].cnt;
u=t[u].ch[1]; //那么当前排名的数一定会在右儿子上寻找
}
else//否则的话在当前节点或者左儿子上查找
{
if(t[y].size>=x) //左儿子的节点数足够
u=y;
else return t[u].val;//否则就是在当前节点上
}
}
}
完整代码
#include <bits/stdc++.h>
using namespace std;
const int N=201000;
struct splay_tree
{
int ff,cnt,ch[2],val,size;
} t[N];
int root,tot;
void update(int x)
{
t[x].size=t[t[x].ch[0]].size+t[t[x].ch[1]].size+t[x].cnt;
}
void rotate(int x)
{
int y=t[x].ff;
int z=t[y].ff;
int k=(t[y].ch[1]==x);
t[z].ch[(t[z].ch[1]==y)]=x;
t[x].ff=z;
t[y].ch[k]=t[x].ch[k^1];
t[t[x].ch[k^1]].ff=y;
t[x].ch[k^1]=y;
t[y].ff=x;
update(y);update(x);
}
void splay(int x,int s)
{
while(t[x].ff!=s)
{
int y=t[x].ff,z=t[y].ff;
if (z!=s)
(t[z].ch[0]==y)^(t[y].ch[0]==x)?rotate(x):rotate(y);
rotate(x);
}
if (s==0)
root=x;
}
void find(int x)
{
int u=root;
if (!u)
return ;
while(t[u].ch[x>t[u].val] && x!=t[u].val)
u=t[u].ch[x>t[u].val];
splay(u,0);
}
void insert(int x)
{
int u=root,ff=0;
while(u && t[u].val!=x)
{
ff=u;
u=t[u].ch[x>t[u].val];
}
if (u)
t[u].cnt++;
else
{
u=++tot;
if (ff)
t[ff].ch[x>t[ff].val]=u;
t[u].ch[0]=t[u].ch[1]=0;
t[tot].ff=ff;
t[tot].val=x;
t[tot].cnt=1;
t[tot].size=1;
}
splay(u,0);
}
int Next(int x,int f)
{
find(x);
int u=root;
if (t[u].val>x && f)
return u;
if (t[u].val<x && !f)
return u;
u=t[u].ch[f];
while(t[u].ch[f^1])
u=t[u].ch[f^1];
return u;
}
void Delete(int x)
{
int last=Next(x,0);
int Net=Next(x,1);
splay(last,0);
splay(Net,last);
int del=t[Net].ch[0];
if (t[del].cnt>1)
{
t[del].cnt--;
splay(del,0);
}
else
t[Net].ch[0]=0;
}
int kth(int x)
{
int u=root;
while(t[u].size<x)
return 0;
while(1)
{
int y=t[u].ch[0];
if (x>t[y].size+t[u].cnt)
{
x-=t[y].size+t[u].cnt;
u=t[u].ch[1];
}
else if (t[y].size>=x)
u=y;
else
return t[u].val;
}
}
int main()
{
int n;
scanf("%d",&n);
insert(1e9);
insert(-1e9);
while(n--)
{
int opt,x;
scanf("%d%d",&opt,&x);
if (opt==1)
insert(x);
if (opt==2)
Delete(x);
if (opt==3)
{
find(x);
printf("%d\n",t[t[root].ch[0]].size);
}
if (opt==4)
printf("%d\n",kth(x+1));
if (opt==5)
printf("%d\n",t[Next(x,0)].val);
if (opt==6)
printf("%d\n",t[Next(x,1)].val);
}
return 0;
}
/*
插入数值x。
删除数值x(若有多个相同的数,应只删除一个)。
查询数值x的排名(若有多个相同的数,应输出最小的排名)。
查询排名为x的数值。
求数值x的前驱(前驱定义为小于x的最大的数)。
求数值x的后继(后继定义为大于x的最小的数)。
*/