题目链接:
洛谷 P3369 【模板】普通平衡树(Treap/SBT)
BZOJ 3224: Tyvj 1728 普通平衡树
第一次尝试
第一次splay板子是大佬教给我的,全部用指针完成了splay的基本操作。当时的我码力还是不足,调试了半天极其低级的错误,当时的我还把他们记载在下面:
- [attempt 1] CE 由于构造函数未在结构体里声明
struct node
{
int val,siz,cnt;
node *son[2],*fa;
node(const int &k); //<-this line.
int dir(){return fa->son[1]==this;}
void upd(){siz=son[0]->siz+son[1]->siz+cnt;}
}*nil=new node(0),*RT,*flag;
node::node(const int &k)
{
val=k,siz=cnt=1,son[0]=son[1]=fa=nil;
}
- [attempt 2] WA rotate函数错误
void rotate(node *rt,int d)
{
node *t=rt->son[d^1];
rt->son[d^1]=t->son[d];
if(rt->son[d^1]!=nil)rt->son[d^1]->fa=rt;
t->son[d]=rt; //<-this line.
if(rt->fa!=nil)rt->fa->son[rt->dir()]=t;
t->fa=rt->fa;rt->fa=t;
rt->upd();t->upd();
return;
}
- [attempt 3] WA delete函数错误
void del(int x)
{
node *l=lower(RT,x),*r=upper(RT,x);
if(l==nil&&r==nil)
{
if(RT->cnt==1)RT=nil;
else RT->cnt--,RT->siz--;
return;
}
if(l==nil&&r!=nil)
{
splay(r,nil);
if(RT->son[0]->cnt==1)RT->son[0]=nil;
else RT->son[0]->cnt--,RT->son[0]->siz--;
RT->upd();
return;
}
if(r==nil)
{
splay(l,nil);
if(RT->son[1]->cnt==1)RT->son[1]=nil;
else RT->son[1]->cnt--,RT->son[1]->siz--;
RT->upd();
return;
}
splay(l,nil);
splay(r,RT);
node *obj=RT->son[1]->son[0];
if(obj->cnt==1)RT->son[1]->son[0]=nil;
else obj->cnt--,obj->siz--;
RT->son[1]->upd();
RT->upd();
return;
}
这里我调试了很久,比如虚拟节点不能update更新子树大小,再比如删除时只更新了临时变量,没有更改其父亲的儿子指针等等。
第二次尝试
这一次的我经历了一年多的沉淀然而还是那么菜,选择了用数组完成splay这个数据结构。
#include<cstdio>
#include<iostream>
#include<cstring>
using namespace std;
inline int read(){
int x=0,y=1;char c=getchar();
while(!isdigit(c)){if(c=='-')y=-y;c=getchar();}
while(isdigit(c)){x=x*10+c-'0';c=getchar();}
return x*y;
}
const int INF = 100000000;
const int MAXN = 100010;
int son[MAXN][2],fa[MAXN];
int siz[MAXN],rev[MAXN],cnt[MAXN],key[MAXN];
int tot,root;
int getd(int x){return son[fa[x]][1] == x;}
int push_up(int x){
siz[x] = siz[son[x][0]] + siz[son[x][1]] + cnt[x];
return 0;
}
void push_down(int x){
if(rev[x]){
rev[x]^=1,rev[son[x][0]]^=1,rev[son[x][1]]^=1;
swap(son[x][0],son[x][1]);
}
}
void rotate(int x){
int fat=fa[x],gra=fa[fat],dir=getd(x),fad=getd(fat);
son[fat][dir]=son[x][dir^1];
fa[son[x][dir^1]]=fat;
son[x][dir^1]=fat;
fa[fat]=x;
if(gra)son[gra][fad]=x;
fa[x]=gra;
push_up(fat);
push_up(x);
}
int splay(int x,int to){
while(fa[x]!=to){
if(fa[fa[x]]!=to&&getd(fa[x])==getd(x))rotate(fa[x]);
rotate(x);
}
if(to==0)root=x;
return 0;
}
void init(){
key[1]=-INF,key[2]=INF;
siz[1]=1,siz[2]=2;
cnt[1]=cnt[2]=1;
fa[1]=2,son[2][0]=1;
root=tot=2;
}
int insert(int now,int fat,int dir,int k){
if(now==0){
fa[++tot]=fat,son[fat][dir]=tot;
cnt[tot]=siz[tot]=1;
key[tot]=k;
return tot;
}
if(key[now]==k){
siz[now]++;
cnt[now]++;
return now;
}
if(key[now]>k)return insert(son[now][0],now,0,k)+push_up(now);
if(key[now]<k)return insert(son[now][1],now,1,k)+push_up(now);
return 0;
}
int Find(int now,int x){
if(now==0)return 0;
if(key[now]==x)return now;
if(key[now]>x)return Find(son[now][0],x);
if(key[now]<x)return Find(son[now][1],x);
}
int get_rank(int now,int k){
if(now==0)return printf("error\n");
if(key[now]==k)return siz[son[now][0]]+1;
if(key[now]>k)return get_rank(son[now][0],k);
if(key[now]<k)return get_rank(son[now][1],k)+siz[now]-siz[son[now][1]];
return 0;
}
int find_rank(int now,int k){
if(now==0)return 0;
const int L=siz[son[now][0]],N=cnt[now],R=siz[son[now][1]];
if(k<=L)return find_rank(son[now][0],k);
if(L<k&&k<=L+N){splay(now,0);return key[now];}
if(k>L+N)return find_rank(son[now][1],k-L-N);
}
inline int max_key(int x,int y){key[0]=-INF;return key[x]>key[y]?x:y;}
int prev(int now,int k){
if(now==0)return 0;
if(key[now]>=k)return prev(son[now][0],k);
return max_key(prev(son[now][1],k),now);
}
inline int min_key(int x,int y){key[0]=INF;return key[x]<key[y]?x:y;}
int next(int now,int k){
if(now==0)return 0;
if(key[now]<=k)return next(son[now][1],k);
return min_key(next(son[now][0],k),now);
}
int Delete(int k){
splay(prev(root,k),0);
splay(next(root,k),root);
int t=son[son[root][1]][0];
if(cnt[t]>1){cnt[t]--;siz[t]--;}
else son[son[root][1]][0]=0;
push_up(son[root][1]);
push_up(root);
}
int main(){
init();
int n=read();
while(n--){
int opt=read(),x=read();
switch(opt){
case 1:{
splay(insert(root,0,0,x),0);
break;
}
case 2:{
Delete(x);
break;
}
case 3:{
printf("%d\n",get_rank(root,x)-1);
splay(Find(root,x),0);
break;
}
case 4:{
printf("%d\n",find_rank(root,x+1));
break;
}
case 5:{
printf("%d\n",key[prev(root,x)]);
break;
}
case 6:{
printf("%d\n",key[next(root,x)]);
break;
}
}
//for(int i=1;find_rank(root,i)!=INF;i++){
// printf("%d ",find_rank(root,i));
//}
//printf("\n");
}
}
关于splay的删除,网上的版本不一,但经过我的搜索,最快的应该是下面这种:
可以说是非常稳了……
代码如下:
#include<bits/stdc++.h>
using namespace std;
struct node
{
int val,siz,cnt;
node *son[2],*fa;
node(const int &k);
int dir(){return fa->son[1]==this;}
void upd(){siz=son[0]->siz+son[1]->siz+cnt;}
}*nil=new node(0),*RT,*flag;
node::node(const int &k)
{
val=k,siz=cnt=1,son[0]=son[1]=fa=nil;
}
void clear()
{
nil->siz=nil->cnt=0,RT=nil;
return;
}
void rotate(node *rt,int d)
{
node *t=rt->son[d^1];
rt->son[d^1]=t->son[d];
if(rt->son[d^1]!=nil)rt->son[d^1]->fa=rt;
t->son[d]=rt;
if(rt->fa!=nil)rt->fa->son[rt->dir()]=t;
t->fa=rt->fa;rt->fa=t;
rt->upd();t->upd();
return;
}
void splay(node *rt,node *to)
{
while(rt->fa!=to)
{
if(rt->fa->fa!=to&&rt->dir()==rt->fa->dir())
rotate(rt->fa->fa,rt->dir()^1);
rotate(rt->fa,rt->dir()^1);
}
if(to==nil)RT=rt;
return;
}
void add(node *&rt,node *fa,int x)
{
if(rt==nil)
{
rt=new node(x);
rt->fa=fa;
flag=rt;
return;
}
if(rt->val==x){rt->cnt++;flag=rt;}
else if(rt->val>x){add(rt->son[0],rt,x);}
else {add(rt->son[1],rt,x);}
rt->upd();
return;
}
void insert(int x)
{
add(RT,nil,x);
splay(flag,nil);
return;
}
node *lower(node *rt,int x)
{
if(rt==nil)return nil;
if(rt->val>=x)return lower(rt->son[0],x);
else{
node *t=lower(rt->son[1],x);
return t==nil?rt:t;
}
}
node *upper(node *rt,int x)
{
if(rt==nil)return nil;
if(rt->val<=x)return upper(rt->son[1],x);
else
{
node *t=upper(rt->son[0],x);
return t==nil?rt:t;
}
}
void find(node *rt,int x){
if(rt==nil)return;
if(rt->val==x){flag=rt;return;}
if(rt->val>x)find(rt->son[0],x);
else find(rt->son[1],x);
return;
}
void del(int x)
{
node *l=lower(RT,x),*r=upper(RT,x);
if(l==nil&&r==nil)
{
if(RT->cnt==1)RT=nil;
else RT->cnt--,RT->siz--;
return;
}
if(l==nil&&r!=nil)
{
splay(r,nil);
if(RT->son[0]->cnt==1)RT->son[0]=nil;
else RT->son[0]->cnt--,RT->son[0]->siz--;
RT->upd();
return;
}
if(r==nil)
{
splay(l,nil);
if(RT->son[1]->cnt==1)RT->son[1]=nil;
else RT->son[1]->cnt--,RT->son[1]->siz--;
RT->upd();
return;
}
splay(l,nil);
splay(r,RT);
node *obj=RT->son[1]->son[0];
if(obj->cnt==1)RT->son[1]->son[0]=nil;
else obj->cnt--,obj->siz--;
RT->son[1]->upd();
RT->upd();
return;
}
int rand(int x)
{
node *rt=RT;int res=1;
while(rt->val!=x)
{
if(rt->val>x)rt=rt->son[0];
else res=res+rt->son[0]->siz+rt->cnt,rt=rt->son[1];
}
res+=rt->son[0]->siz;
return res;
}
int rand(node *rt,int x)
{
if(rt->son[0]->siz>=x)return rand(rt->son[0],x);
if(rt->son[0]->siz+rt->cnt>=x)return rt->val;
return rand(rt->son[1],x-rt->cnt-rt->son[0]->siz);
}
int read()
{
int x=0,y=1;char c=getchar();
while(!isdigit(c)){if(c=='-')y=-y;c=getchar();}
while(isdigit(c))x=x*10+c-'0',c=getchar();
return x*y;
}
int main()
{
clear();
int n=read();
while(n--)
{
int opt=read(),x=read();
switch(opt)
{
case 1:
{
insert(x);
break;
}
case 2:
{
del(x);
break;
}
case 3:
{
printf("%d\n",rand(x));
break;
}
case 4:
{
printf("%d\n",rand(RT,x));
break;
}
case 5:
{
printf("%d\n",lower(RT,x)->val);
break;
}
case 6:
{
printf("%d\n",upper(RT,x)->val);
break;
}
}
}
return 0;
}
/*
10
1 1
4 1
1 3
1 4
*/