splay的功能十分强大,但是操作复杂,下面总结一下splay的几种常见操作。
首先明确平衡树的概念:每一个节点的左子树都比它小,右子树都比它大。
变量定义:
ch[x][0/1]:x的左/右儿子
size[x]:x所在子树的大小
fa[x]:x的父亲
清理一个节点:
int clear(int x)
{
ch[x][0]=ch[x][1]=size[x]=0;
}
重新计算一个节点的有关数据:
int update(int x)
{
size[x]=size[ch[x][0]]+size[ch[x][1]]+1;
//等等有关数据
}
接着就是splay的核心旋转操作。
无论是左旋还是右旋,画一画图就可以明白他们是大同小异的,所以这里给出一个通法:
设当前的核心点是x,fa[x]=y,fa[y]=z。
1、让x的与y不同边的儿子认y为父亲,同时y也连向这个点。
2、让y认x为父亲,同时x连向y。
3、让x认z为父亲,若z不为0,则z连向x。
4、先update(y),后update(x)。
int rotate(int x)
{
int i,y,z,kind,k1;
y=fa[x];
if(ch[y][0]==x)kind=0;
else kind=1;
z=fa[y];
if(ch[z][0]==y)k1=0;
else k1=1;
ch[y][kind]=ch[x][kind^1];fa[ch[y][kind]]=y;
fa[y]=x;ch[x][kind^1]=y;
fa[x]=z;
if(z!=0){ch[z][k1]=x;}
update(y);update(x);
}
这是一个单次旋转,而splay就是有多个单次旋转组成。下面给出splay的程序(把sd旋转到td下面):
int splay(int sd,int td)
{
int x=sd,y,z;
while(fa[x]!=td)
{
y=fa[x];z=fa[y];
if(z==td)rotate(x);
else
if(ch[z][0]==y&&ch[y][0]==x||ch[z][1]==y&&ch[y][1]==x){rotate(y);rotate(x);}
else {rotate(x);rotate(x);}
}
if(td==0)root=x;
}
上面旋转要注意的地方是,每一次旋转两次。若x、y、z在同一条直线上,则先旋y再旋x,否则把x旋两遍。
最后不要忘记更新root。
splay最核心的旋转操作已经讲完了,那么它的应用就很简单了。
插入:从上往下找到合适的位置,插入了之后把新点旋到根。
修改:把要修改的点旋到根,修改之后update一下。
删除:1、假设要删除x~y,那么把x-1旋到根,y+1旋到根的右儿子,那么y+1的左子树即为x~y,直接删除(clear)即可。删完之后还是要update一下。
2、删除x还有一种做法就是把x旋转到根,然后删掉它,接着选择它的前驱或者后继做根。
查询一个数的排名/排名为某个值的数:按照平衡树的定义往下找就可以了。
查询一个区间的答案:update时维护区间区间数据。查询x~y时把x-1旋到根,y+1旋到根的右儿子,那么y+1的左子树即为要查询的区间。
求x的前驱或后继:把x旋转到根,然后左子树的最右点就是它的前驱,右子树的最左点就是它的后继。
区间翻转:对于这一操作我们要使用类似线段树的懒惰标记。如果一个点被表上了翻转标记,那么在查询到它时就把它翻转。所谓的翻转就是把它的左右儿子交换。
在翻转x~y时,把x-1旋转到根,把y+1旋转到根的右儿子,然后在根的右儿子的左儿子处打上翻转标记即可。
注意在查询时不要忘记下传标记。另外,翻转标记只需xor 1即可,因为翻转两次等于没翻转。
splay的时间复杂度为O(nlogn+Tlogn),其中n是节点个数,T是操作个数。这个我不会证明。
以JZOJ2744为例,贴一下完整代码:
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#define MAXN 200010
using namespace std;
int a[MAXN],ch[MAXN][2],size[MAXN],sum[MAXN],mx[MAXN],lmx[MAXN],rmx[MAXN],n,m,ans,x,v,y,root,ne;
int val[MAXN],fa[MAXN];
char type;
int node(int v)
{
size[ne]=1;
sum[ne]=v;mx[ne]=v;val[ne]=v;
lmx[ne]=max(v,0);rmx[ne]=max(v,0);
}
int get(int num)
{
int i=root,s=0;
while(true)
{
if(s+size[ch[i][0]]+1==num)return i;
if(s+size[ch[i][0]]+1<num){s=s+size[ch[i][0]]+1;i=ch[i][1];}
else {i=ch[i][0];}
}
}
int clear(int x)
{
size[x]=0;ch[x][0]=0;ch[x][1]=0;fa[x]=0;
sum[x]=0;mx[x]=0;lmx[x]=0;rmx[x]=0;val[x]=0;
}
int update(int x)
{
int l=ch[x][0],r=ch[x][1];
size[x]=size[l]+size[r]+1;
sum[x]=sum[l]+sum[r]+val[x];
mx[x]=max(max(mx[l],mx[r]),rmx[l]+val[x]+lmx[r]);
lmx[x]=max(lmx[l],sum[l]+val[x]+lmx[r]);
rmx[x]=max(rmx[r],sum[r]+val[x]+rmx[l]);
}
int rotate(int x)
{
int i,y=fa[x],kind,z,k1;
if(ch[y][0]==x)kind=0;
else kind=1;
z=fa[y];
if(ch[z][0]==y)k1=0;
else k1=1;
ch[y][kind]=ch[x][kind^1];fa[ch[y][kind]]=y;
fa[y]=x;ch[x][kind^1]=y;
fa[x]=z;
if(z!=0){ch[z][k1]=x;}
update(y);update(x);
}
int splay(int sd,int td)
{
int x=sd,y,z;
while(fa[x]!=td)
{
y=fa[x];z=fa[y];
if(z==td)rotate(x);
else
if(ch[z][0]==y&&ch[y][0]==x||ch[z][1]==y&&ch[y][1]==x){rotate(y);rotate(x);}
else {rotate(x);rotate(x);}
}
if(td==0)root=x;
}
int insert(int num,int v)
{
if(root==0)
{
ne++;node(v);
root=ne;
return 0;
}
int i,x1=get(num-1),x0=get(num);
splay(x1,0);splay(x0,x1);
i=ch[root][1];
while(ch[i][0]!=0)i=ch[i][0];
ne++;ch[i][0]=ne;fa[ne]=i;node(v);
splay(ne,0);
}
int build(int x,int v)
{
int i;
if(root==0)
{
ne++;node(v);
root=ne;
return 0;
}
i=root;
while(ch[i][1]!=0)i=ch[i][1];
ne++;ch[i][1]=ne;fa[ne]=i;node(v);
splay(ne,0);
}
int main()
{
int i,j,x1,x0,y1;
scanf("%d\n",&n);
for(i=1;i<n;i++)scanf("%d ",&a[i]);
scanf("%d\n%d\n",&a[n],&m);
memset(mx,-0x7f,sizeof(mx));
n++;
for(i=0;i<=n;i++)build(i,a[i]);
while(m>=1)
{
scanf("%c ",&type);
if(type=='I')
{
scanf("%d %d\n",&x,&v);x++;
insert(x,v);
}
if(type=='D')
{
scanf("%d\n",&x);x++;
x1=get(x-1);x0=get(x+1);
splay(x1,0);splay(x0,x1);
i=ch[root][1];
while(ch[i][0]!=0)i=ch[i][0];
ch[fa[i]][0]=0;clear(i);
update(x0);update(x1);
}
if(type=='R')
{
scanf("%d %d\n",&x,&v);x++;
x0=get(x);
splay(x0,0);
val[root]=v;
update(root);
}
if(type=='Q')
{
scanf("%d %d\n",&x,&y);x++;y++;
x1=get(x-1);y1=get(y+1);
splay(x1,0);splay(y1,x1);
printf("%d\n",mx[ch[ch[root][1]][0]]);
}
m--;
}
}
下面的代码是“排序机械臂”(JZOJ3599)的,里面涉及翻转操作:
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#define MAXN 100010
using namespace std;
struct data
{
int v;
int num;
};
data a[MAXN];
int size[MAXN],mi[MAXN],ch[MAXN][2],fa[MAXN],g[MAXN],val[MAXN],mi1[MAXN],val1[MAXN],n,ans,root,ne;
int inf=2000000001;
int get(int num)
{
int i=root,s=0,t;
while(i!=0)
{
if(g[i]==1)
{
t=ch[i][0];ch[i][0]=ch[i][1];ch[i][1]=t;
g[ch[i][0]]^=1;g[ch[i][1]]^=1;
g[i]=0;
}
if(s+size[ch[i][0]]+1==num)return i;
if(s+size[ch[i][0]]+1<num)s=s+size[ch[i][0]]+1,i=ch[i][1];
else i=ch[i][0];
}
}
int node(int v,int v1)
{
size[ne]=1;
val[ne]=v;mi[ne]=v;
val1[ne]=v1;mi1[ne]=v1;
}
int clear(int x)
{
size[x]=0;val[x]=inf;mi[x]=inf;val1[x]=inf;mi1[x]=inf;
ch[x][0]=0;ch[x][1]=0;g[x]=0;
}
int update(int x)
{
size[x]=size[ch[x][0]]+size[ch[x][1]]+1;
mi[x]=min(min(mi[ch[x][0]],mi[ch[x][1]]),val[x]);
mi1[x]=inf;
if(mi[x]==mi[ch[x][0]])mi1[x]=min(mi1[x],mi1[ch[x][0]]);
if(mi[x]==mi[ch[x][1]])mi1[x]=min(mi1[x],mi1[ch[x][1]]);
if(mi[x]==val[x])mi1[x]=min(mi1[x],val1[x]);
}
int rotate(int x)
{
int y=fa[x],z=fa[y],k1,k2;
if(ch[y][0]==x)k1=0;
else k1=1;
if(ch[z][0]==y)k2=0;
else k2=1;
ch[y][k1]=ch[x][k1^1];fa[ch[y][k1]]=y;
fa[y]=x;ch[x][k1^1]=y;
fa[x]=z;if(z!=0)ch[z][k2]=x;
update(y);update(x);
}
int splay(int sd,int td)
{
int x=sd,y,z;
while(fa[x]!=td)
{
y=fa[x];z=fa[y];
if(z==td)rotate(x);
else
{
if(ch[z][0]==y&&ch[y][0]==x||ch[z][1]==y&&ch[y][1]==x){rotate(y);rotate(x);}
else {rotate(x);rotate(x);}
}
}
if(td==0)root=x;
}
int insert(int num,int v,int v1)
{
int i;
if(root==0)
{
ne++;node(v,v1);
root=ne;
return 0;
}
i=root;
while(ch[i][1]!=0)i=ch[i][1];
ne++;fa[ne]=i;ch[i][1]=ne;node(v,v1);
splay(ne,0);
}
int find(int v,int v1)
{
int i=root,s=0,t;
while(i!=0)
{
if(g[i]==1)
{
t=ch[i][0];ch[i][0]=ch[i][1];ch[i][1]=t;
g[ch[i][0]]^=1;g[ch[i][1]]^=1;
g[i]=0;
}
if(val[i]==v&&val1[i]==v1)return s+size[ch[i][0]]+1;
else if(mi[ch[i][0]]<mi[ch[i][1]]||mi[ch[i][0]]==mi[ch[i][1]]&&mi1[ch[i][0]]<mi1[ch[i][1]])i=ch[i][0];
else s=s+size[ch[i][0]]+1,i=ch[i][1];
}
}
int game(const data &a,const data &b)
{
if(a.v<b.v)return 1;
if(a.v>b.v)return 0;
if(a.num<b.num)return 1;
else return 0;
}
int main()
{
int i,j,x1,x2;
scanf("%d",&n);
for(i=1;i<=n;i++)scanf("%d",&a[i].v),a[i].num=i;
memset(mi,0x7f,sizeof(mi));
memset(mi1,0x7f,sizeof(mi1));
a[0].v=inf;a[n+1].v=inf;
insert(0,a[0].v,inf);
for(i=1;i<=n+1;i++)insert(i,a[i].v,i);
sort(a+1,a+n+1,game);
for(i=1;i<=n;i++)
{
ans=find(a[i].v,a[i].num);
printf("%d ",ans-1+i-1);
x1=get(ans-1);x2=get(ans+1);
splay(x1,0);splay(x2,x1);
j=ch[root][1];
while(ch[j][0]!=0)j=ch[j][0];
ch[fa[j]][0]=0;clear(j);
x1=1;x2=get(ans);
splay(x1,0);splay(x2,x1);
g[ch[x2][0]]^=1;
}
}