BZOJ 3224 普通平衡树
Description
您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
1. 插入x数
2. 删除x数(若有多个相同的数,因只删除一个)
3. 查询x数的排名(若有多个相同的数,因输出最小的排名)
4. 查询排名为x的数
5. 求x的前驱(前驱定义为小于x,且最大的数)
6. 求x的后继(后继定义为大于x,且最小的数)
Input
第一行为n,表示操作的个数,下面n行每行有两个数opt和x,opt表示操作的序号(1<=opt<=6)
Output
对于操作3,4,5,6每行输出一个数,表示对应答案
Sample Input
10
1 106465
4 1
1 317721
1 460929
1 644985
1 84185
1 89851
6 81968
1 492737
5 493598
Sample Output
106465
84185
492737
HINT
1.n的数据范围:n<=100000
2.每个数的数据范围:[-1e7,1e7]
Solution
首先来说一下splay的基本操作:
- insert
和二叉搜索树的插入一样
只不过要把当前加入的节点旋转到根
void insert(int &k,int x,int lst)
{
if(!k)
{
k=++size;t[k].fa=lst;
t[k].v=x;t[k].size=t[k].w=1;
return;
}
t[k].size++;
if(t[k].v==x)t[k].w++;
else if(t[k].v>x)insert(t[k].l,x,k);
else insert(t[k].r,x,k);
}
- delete
把x的前驱和后继找出来
然后把前驱旋转到根,后继旋转到根的右节点
根的右节点左子树只有一个元素就是x(仔细思考为什么)
之后如果该节点的w值(x相同的元素的个数)大于1则直接减一
否则在树中删除它
void del(int &k)
{
splay(k,root);
if(t[root].w>1){t[root].w--,t[root].size--;return;}
if(!(t[root].l&&t[root].r)){root=t[root].l+t[root].r;t[root].fa=0;return;}
ans1=ans2=0;
get_pre(root,t[root].v);
get_pro(root,t[root].v);
splay(ans1,root);
splay(ans2,t[root].r);
int pos=t[t[root].r].l;
t[pos].l=t[pos].r=t[pos].size=t[pos].w=t[pos].v=t[pos].fa=0;
t[t[root].r].l=0;
update(t[root].r),update(root);
}
get_rank(询问x的排名)
先找到那个点然后旋转到根上
输出旋转后树的左子树size+1get_x(询问排名为rank的数)
从根开始递归查找
分三种情况:- 比当前节点左子树的元素个数小或等于
此时查找左节点 - 比当前节点元素个数+左子树元素个数大
此时查找右子树中排名为(x-左子树元素个数-当前节点元素个数)的数 - 等于当前节点
答案就是当前节点
- 比当前节点左子树的元素个数小或等于
int get_x(int k,int x)
{
if(!k)return k;
if(t[t[k].l].size>=x)return get_x(t[k].l,x);
else if(t[t[k].l].size+t[k].w<x)return get_x(t[k].r,x-t[t[k].l].size-t[k].w);
else return k;
}
- get_pre(找x的前驱)
和二叉搜索树的查找方法一样
void get_pre(int k,int x)
{
if(!k)return;
if(t[k].v<x){ans1=k;get_pre(t[k].r,x);}
else get_pre(t[k].l,x);
}
- get_suf(找x的后继)
和二叉搜索树的查找方法一样
void get_suf(int k,int x)
{
if(!k)return;
if(t[k].v>x){ans2=k;get_suf(t[k].l,x);}
else get_suf(t[k].r,x);
}
说了这么多大家可能觉得splay就相当于普通的二叉搜索树
但是splay有旋转操作可以使树变得平均
旋转有两种,单旋和双旋:
单旋容易写但是可以卡成每次O(n)
双旋难写但是可以使时间均摊到每次O(logn)
在这里重点介绍双旋
双旋有三种情况
1.如果x的父节点是根,那么直接旋上去
2.设x的父节点为y,y的父节点为z,y不为根节点,且x,y,z排列成一字形那么我们先旋转y到根节点,再旋转x到根节点
3.设x的父节点为y,y的父节点为z,y不为根节点,且x,y,z排列成之字形那么我们旋转两次x到根节点
证明复杂度
splay差不多就这些
上3224代码
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
int root,ans1,ans2,size;
struct splay{
int l,r,size,w,v,fa;
}t[200001];
inline int in()
{
int x=0,f=1;char ch=getchar();
while((ch<'0'||ch>'9')&&ch!='-')ch=getchar();
if(ch=='-')ch=getchar(),f=-1;
while(ch<='9'&&ch>='0')x=x*10+ch-'0',ch=getchar();
return x*f;
}
void update(int k)
{
t[k].size=t[t[k].l].size+t[t[k].r].size+t[k].w;
}
void raxe(int x,int &k)
{
int y,z;
y=t[x].fa,z=t[y].fa;
if(!z)k=x;
else {if(t[z].l==y)t[z].l=x;else t[z].r=x;}
t[x].fa=z,t[y].fa=x,t[t[x].r].fa=y;
t[y].l=t[x].r,t[x].r=y;
update(y);update(x);
}
void laxe(int x,int &k)
{
int y,z;
y=t[x].fa,z=t[y].fa;
if(!z)k=x;
else {if(t[z].l==y)t[z].l=x;else t[z].r=x;}
t[x].fa=z,t[y].fa=x,t[t[x].l].fa=y;
t[y].r=t[x].l,t[x].l=y;
update(y);update(x);
}
void splay(int x,int &k)
{
if(!x)return;
int y,z;
while(x!=k)
{
y=t[x].fa;
z=t[y].fa;
if(y!=k)
{
if(t[y].l==x&&t[z].l==y)raxe(y,k);
else if(t[y].r==x&&t[z].r==y)laxe(y,k);
else if(t[y].r==x&&t[z].l==y)laxe(x,k);
else raxe(x,k);
}
if(t[t[x].fa].l==x)raxe(x,k);
else laxe(x,k);
}
}
void insert(int &k,int x,int lst)
{
if(!k)
{
k=++size;t[k].fa=lst;
t[k].v=x;t[k].size=t[k].w=1;
return;
}
t[k].size++;
if(t[k].v==x)t[k].w++;
else if(t[k].v>x)insert(t[k].l,x,k);
else insert(t[k].r,x,k);
}
void get_pre(int k,int x)
{
if(!k)return;
if(t[k].v<x){ans1=k;get_pre(t[k].r,x);}
else get_pre(t[k].l,x);
}
void get_pro(int k,int x)
{
if(!k)return;
if(t[k].v>x){ans2=k;get_pro(t[k].l,x);}
else get_pro(t[k].r,x);
}
int get(int k,int x)
{
if(!k)return k;
if(t[k].v>x)return get(t[k].l,x);
else if(t[k].v==x)return k;
else return get(t[k].r,x);
}
void del(int &k)
{
splay(k,root);
if(t[root].w>1){t[root].w--,t[root].size--;return;}
if(!(t[root].l&&t[root].r)){root=t[root].l+t[root].r;t[root].fa=0;return;}
ans1=ans2=0;
get_pre(root,t[root].v);
get_pro(root,t[root].v);
splay(ans1,root);
splay(ans2,t[root].r);
int pos=t[t[root].r].l;
t[pos].l=t[pos].r=t[pos].size=t[pos].w=t[pos].v=t[pos].fa=0;
t[t[root].r].l=0;
update(t[root].r),update(root);
}
int GET(int k,int x)
{
if(!k)return k;
if(t[t[k].l].size>=x)return GET(t[k].l,x);
else if(t[t[k].l].size+t[k].w<x)return GET(t[k].r,x-t[t[k].l].size-t[k].w);
else return k;
}
int main()
{
freopen("4543.in","r",stdin);
freopen("4543.out","w",stdout);
int n,anss;
n=in();
for(int i=1;i<=n;i++)
{
int doing=in(),x=in();
if(doing==1)insert(root,x,0);
else if(doing==2)ans1=0,anss=get(root,x),del(anss);
else if(doing==3)ans1=get(root,x),splay(ans1,root),printf("%d\n",t[t[root].l].size+1);
else if(doing==4)ans1=GET(root,x),printf("%d\n",t[ans1].v),splay(ans1,root);
else if(doing==5)get_pre(root,x),printf("%d\n",t[ans1].v),splay(ans1,root);
else get_pro(root,x),printf("%d\n",t[ans2].v),splay(ans2,root);
}
return 0;
}