3224: Tyvj 1728 普通平衡树
Time Limit: 10 Sec
Memory Limit: 128 MB
Submit: 12714 Solved: 5427
[Submit][Status][Discuss] Description
您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
1. 插入x数
2. 删除x数(若有多个相同的数,因只删除一个)
3. 查询x数的排名(若有多个相同的数,因输出最小的排名)
4. 查询排名为x的数
5. 求x的前驱(前驱定义为小于x,且最大的数)
6. 求x的后继(后继定义为大于x,且最小的数)
Input
第一行为n,表示操作的个数,下面n行每行有两个数opt和x,opt表示操作的序号(1<=opt<=6)
Output
对于操作3,4,5,6每行输出一个数,表示对应答案
Sample Input
10
1 106465
4 1
1 317721
1 460929
1 644985
1 84185
1 89851
6 81968
1 492737
5 493598
Sample Output
106465
84185
492737
HINT
1.n的数据范围:n<=100000
2.每个数的数据范围:[-2e9,2e9]
这是一道模板题,调了一上午+半个下午,调出无数错误最后AC,懂得了splay树的基础用法,理解了双旋的含义
splay(int x,int ff)表示将编号为x的节点转到ff的儿子,ff=0表示转到根。rotate(x)表示将x上旋一个,
splay核心代码:
void splay(int x,int ff) {
for(int F;(F=fa[x])!=ff;rotate(x))
{
if(fa[F])
{
if(!((gets(F))^(gets(x)))) rotate(F);
else rotate(x);
}
}
if(!ff) root=x; }
我们基于以下规则进行操作:
1、如果x本来就是ff的儿子直接跳出循环,完毕(因为他会先判条件再执行改变操作,所以这种情况他会直接跳出)
2、如果他的爸爸的爸爸是ff的儿子,就直接上旋一次即可。
若非以上两种 就先转一次,然后不停执行3,直到是以上以上两种。
3、他的爸爸的爸爸并不是ff的儿子那就分类讨论
》》1:要是他爸爸是他的爷爷的左儿子,他是他爸爸的左儿子(或他爸爸是他的爷爷的右儿子,他是他爸爸的右儿子)–>先转一次它爸爸,再转一次他(有利于复杂度,大家可以自己证证)。
》》2、否则就把他转两次。
这样就可以了。
(PS:这并非最好的旋转方法,但绝对不会差很多也不会被卡成n^2,这也是为了省代码)
标准的转法:先判1,在判2;如果不是就比上述方法少了一步先把x转一次的步骤,直接跳到3,一直转直到命题2成立。
上述有一个gets函数,这是用来判是左儿子还是右儿子的(其实就是为了省代码,偷懒)。
int gets(int x)
{
return (ch[fa[x]][1]==x);
}
1表示是右儿子,0表示是左儿子。这么做可以省代码,又十分简洁。
接下来呢,我们来看看rotate操作,这个是splay的核心
> void rotate(int x)
{
int F=fa[x];
int d=gets(x);
int GF=fa[F];
int dd=gets(F);
if(GF)
ch[GF][dd]=x;
ch[F][d]=ch[x][!d];
fa[ch[x][!d]]=F;
ch[x][!d]=F;
fa[x]=GF;
fa[F]=x;
pushup(F);pushup(x);
return;
}
其实就是先判这是左旋还是右旋,然后代码应该看得懂,这样换下儿子换下位置,仍满足BST性质,所以可以这么做,我们旋转这棵树可以防止其退化,并且方便操作。最后别忘记pushup统计下这两个节点的包括自己的子树所含的节点个数。顺序不能换,因为现在F是x的儿子了,且为防止0对结果的影响,S[x]表示这两个节点的包括自己的子树所含的节点个数,num[x]是该重复节点的个数,所以有了以下代码
-
void pushup(int x){S[0]=0;S[x]=S[ch[x][0]]+S[ch[x][1]]+num[x];return;}
现在我们对splay的核心操作已经了如指掌,现在我们可以开始进行操作了,我们支持题中所说的所有操作:
1、insert
void insert(int dd)
{
if(size==0||root==0)
{
num[++size]=1;
S[size]=1;
val[size]=dd;
root=1;
fa[size]=0;
return;
}
int ll=find(dd);
if(val[ll]!=dd)
{
val[++size]=dd;
num[size]=0;
fa[size]=ll;
if(dd<val[ll]) ch[ll][0]=size;else ch[ll][1]=size;
ch[size][0]=ch[size][1]=0;
ll=size;
}
splay(ll,0);
num[ll]++;
pushup(ll);
}
第一句话好理解,要是该树是空的那就造一棵树,自己当根节点。(其实不用特判直接删了都能过)
然后就是找值为dd的点,如果有就直接num[ll]++,没有就创一个新的节点,注意find会自动返回添加节点的父亲。记录一下val值这样就能OK了。
2、del删除操作
void del(int x)
{
int node=find(x);
splay(node,0);
num[node]--;
if(num[node]==0)
{
if(!ch[node][0]&&!ch[node][1])root=0;
else if(!ch[node][0])
{
root=ch[node][1];fa[ch[node][1]]=0;
}
else if(!ch[node][1])
{
root=ch[node][0];fa[ch[node][0]]=0;
}
else
{
int ll=succ(node);
fa[ch[node][0]]=0;
splay(ll,0);
fa[ch[node][1]]=ll;
ch[ll][1]=ch[node][1];
pushup(ll);
}
}
}
删除操作先将其找到,把他旋到根,然后就看他是否有两个儿子,如果他有,那么我们就可以直接做,否则我们先把他删了,分成左、右两颗子树,找到他的前驱,从树中转到根处(这样他一定没有右儿子)再把右子树接到该树上即可。
3、查询一个数的排名,利用size把他旋到树根然后直接统计左子树个数+1(BST性质)即可
int ll=find(it);splay(ll,0);printf("%d\n",S[ch[root][0]]+1);
4、这个也是利用BST的性质
int element(int k)
{
int r=root;
while(1)
{
if(S[ch[r][0]]<k&&S[ch[r][0]]+num[r]>=k) return r;
if(S[ch[r][0]]>=k) r=ch[r][0];
else if(S[ch[r][1]]) {k-=num[r]+S[ch[r][0]];r=ch[r][1];}
}
}
5、6、求前驱,后缀,先把他加入,然后做完前驱(后缀)后将其删除(利用BST性质)。
int succ(int node)
{
int r=ch[node][0];
while(ch[r][1]) r=ch[r][1];
return r;
}
int pred(int node)
{
int r=ch[node][1];
while(ch[r][0]) r=ch[r][0];
return r;
}
以下贴上源代码
#include<bits/stdc++.h>
using namespace std;
#define N 1000001
int num[N],S[N],key[N],fa[N],ch[N][2],val[N];
int n,m;
int opt,it,root,size=0;
void pushup(int x){S[0]=0;S[x]=S[ch[x][0]]+S[ch[x][1]]+num[x];return;}
int gets(int x)
{
return (ch[fa[x]][1]==x);
}
void rotate(int x)
{
int F=fa[x];
int d=gets(x);
int GF=fa[F];
int dd=gets(F);
if(GF)
ch[GF][dd]=x;
ch[F][d]=ch[x][!d];
fa[ch[x][!d]]=F;
ch[x][!d]=F;
fa[x]=GF;
fa[F]=x;
pushup(F);pushup(x);
return;
}
void splay(int x,int ff)
{
for(int F;(F=fa[x])!=ff;rotate(x))
{
if(fa[F])
{
if(!((gets(F))^(gets(x)))) rotate(F);
else rotate(x);
}
}
if(!ff) root=x;
}
int find(int dd)
{
int r=root;
while(val[r]!=dd)
{
if(dd<val[r])
{
if(!ch[r][0]) break;
r=ch[r][0];
}
else
{
if(!ch[r][1]) break;
r=ch[r][1];
}
}
return r;
}
void insert(int dd)
{
int ll=find(dd);
if(val[ll]!=dd)
{
val[++size]=dd;
num[size]=0;
fa[size]=ll;
if(dd<val[ll]) ch[ll][0]=size;else ch[ll][1]=size;
ch[size][0]=ch[size][1]=0;
ll=size;
}
splay(ll,0);
num[ll]++;
pushup(ll);
}
int succ(int node)
{
int r=ch[node][0];
while(ch[r][1]) r=ch[r][1];
return r;
}
int pred(int node)
{
int r=ch[node][1];
while(ch[r][0]) r=ch[r][0];
return r;
}
void del(int x)
{
int node=find(x);
splay(node,0);
num[node]--;
if(num[node]==0)
{
if(!ch[node][0]&&!ch[node][1])root=0;
else if(!ch[node][0])
{
root=ch[node][1];fa[ch[node][1]]=0;
}
else if(!ch[node][1])
{
root=ch[node][0];fa[ch[node][0]]=0;
}
else
{
int ll=succ(node);
fa[ch[node][0]]=0;
splay(ll,0);
fa[ch[node][1]]=ll;
ch[ll][1]=ch[node][1];
pushup(ll);
}
}
}
int element(int k)
{
int r=root;
while(1)
{
if(S[ch[r][0]]<k&&S[ch[r][0]]+num[r]>=k) return r;
if(S[ch[r][0]]>=k) r=ch[r][0];
else if(S[ch[r][1]]) {k-=num[r]+S[ch[r][0]];r=ch[r][1];}
}
}
int main()
{
scanf("%d",&n);
while(n--)
{
scanf("%d%d",&opt,&it);
if(opt==1){insert(it);}
else if(opt==2){del(it);}
else if(opt==3){int ll=find(it);splay(ll,0);printf("%d\n",S[ch[root][0]]+1);}
else if(opt==4){int ll=element(it);printf("%d\n",val[ll]);splay(ll,0);}
else if(opt==5){insert(it);printf("%d\n",val[succ(find(it))]);del(it);}
else {insert(it);printf("%d\n",val[pred(find(it))]);del(it);}
}
}
精简后的代码
#include<bits/stdc++.h>
using namespace std;
#define N 1000001
int num[N],S[N],key[N],fa[N],ch[N][2],val[N];int n,m;int opt,it,root,size=0;
void pushup(int x){S[0]=0;S[x]=S[ch[x][0]]+S[ch[x][1]]+num[x];return;}
int gets(int x){return (ch[fa[x]][1]==x);}
void rotate(int x)
{
int F=fa[x];int d=gets(x);int GF=fa[F];int dd=gets(F);
if(GF)
ch[GF][dd]=x;
ch[F][d]=ch[x][!d];
fa[ch[x][!d]]=F;
ch[x][!d]=F;
fa[x]=GF;
fa[F]=x;
pushup(F);pushup(x);
return;
}
void splay(int x,int ff)
{
for(int F;(F=fa[x])!=ff;rotate(x))
if(fa[F]){if(!((gets(F))^(gets(x)))) rotate(F); else rotate(x);}
if(!ff) root=x;
}
int find(int dd)
{
int r=root;
while(val[r]!=dd)
{
if(dd<val[r]){if(!ch[r][0]) break;r=ch[r][0];}
else{if(!ch[r][1]) break;r=ch[r][1];}
}
return r;
}
void insert(int dd)
{
int ll=find(dd);
if(val[ll]!=dd)
{
val[++size]=dd;
fa[size]=ll;
if(dd<val[ll]) ch[ll][0]=size;else ch[ll][1]=size;
ll=size;
}
splay(ll,0);
num[ll]++;
pushup(ll);
}
int succ(int node)
{
int r=ch[node][0];
while(ch[r][1]) r=ch[r][1];
return r;
}
int pred(int node)
{
int r=ch[node][1];
while(ch[r][0]) r=ch[r][0];
return r;
}
void del(int x)
{
int node=find(x);splay(node,0);num[node]--;
if(num[node]==0)
{
if(!ch[node][0]&&!ch[node][1])root=0;
else if(!ch[node][0]){root=ch[node][1];fa[ch[node][1]]=0;}
else if(!ch[node][1]){root=ch[node][0];fa[ch[node][0]]=0;}
else
{
int ll=succ(node);
fa[ch[node][0]]=0;
splay(ll,0);
fa[ch[node][1]]=ll;
ch[ll][1]=ch[node][1];
pushup(ll);
}
}
}
int element(int k)
{
int r=root;
while(1)
{
if(S[ch[r][0]]<k&&S[ch[r][0]]+num[r]>=k) return r;
if(S[ch[r][0]]>=k) r=ch[r][0];
else if(S[ch[r][1]]) {k-=num[r]+S[ch[r][0]];r=ch[r][1];}
}
}
int main()
{
scanf("%d",&n);
while(n--)
{
scanf("%d%d",&opt,&it);
if(opt==1){insert(it);}
else if(opt==2){del(it);}
else if(opt==3){int ll=find(it);splay(ll,0);printf("%d\n",S[ch[root][0]]+1);}
else if(opt==4){int ll=element(it);printf("%d\n",val[ll]);splay(ll,0);}
else if(opt==5){insert(it);printf("%d\n",val[succ(find(it))]);del(it);}
else {insert(it);printf("%d\n",val[pred(find(it))]);del(it);}
}
}