题意不再赘述。不了解平衡树的读者,请自行查阅教材或网络资料(或者自己脑补)。
下面直接贴出 SBT(Size Balanced Tree)的代码,另附一份提交后超时(TLE)的 Splay Tree 代码。
SBT:
#include<cstdio>
#include<cstring>
#define N 2001000
using namespace std;
// One node of the Size Balanced Tree; nodes live in the static pool s[].
// Index 0 is the null node (all fields zero), so s[0].size == s[0].sc == 0.
struct SBT
{
    int key;   // value stored in this node
    int cnt;   // how many copies of key this node represents
    int l;     // index of the left child (0 = none)
    int r;     // index of the right child (0 = none)
    int size;  // number of nodes in this subtree (used as the balance measure)
    int sc;    // sum of cnt over this subtree (element count with multiplicity)
}s[N];
int root;  // index of the root node (0 = empty tree)
int n;     // number of pool slots handed out so far; new nodes get ++n
int m;     // number of operations, read by main
void left_rotate(int &x)
{
int y=s[x].r;
s[x].r=s[y].l;
s[y].l=x;
s[y].size=s[x].size;
s[y].sc=s[x].sc;
s[x].size=s[s[x].l].size+s[s[x].r].size+1;
s[x].sc=s[s[x].l].sc+s[s[x].r].sc+s[x].cnt;
x=y;
}
void right_rotate(int &x)
{
int y=s[x].l;
s[x].l=s[y].r;
s[y].r=x;
s[y].size=s[x].size;
s[y].sc=s[x].sc;
s[x].size=s[s[x].l].size+s[s[x].r].size+1;
s[x].sc=s[s[x].l].sc+s[s[x].r].sc+s[x].cnt;
x=y;
}
// SBT "maintain": refresh x's cached totals, then restore the size balance
// condition at x. flag == 0 checks whether the left side has become too heavy,
// any other value checks the right side. After a rotation the affected
// subtrees and x itself are re-maintained recursively.
void keep(int &x,int flag)
{
    s[x].size=1+s[s[x].l].size+s[s[x].r].size;
    s[x].sc=s[x].cnt+s[s[x].l].sc+s[s[x].r].sc;
    bool rotated=true;
    if(flag==0)
    {
        const int lc=s[x].l;
        const int other=s[s[x].r].size;
        if(s[s[lc].l].size>other)
            right_rotate(x);                 // left-left case
        else if(s[s[lc].r].size>other)
        {
            left_rotate(s[x].l);             // left-right case
            right_rotate(x);
        }
        else
            rotated=false;
    }
    else
    {
        const int rc=s[x].r;
        const int other=s[s[x].l].size;
        if(s[s[rc].r].size>other)
            left_rotate(x);                  // right-right case
        else if(s[s[rc].l].size>other)
        {
            right_rotate(s[x].r);            // right-left case
            left_rotate(x);
        }
        else
            rotated=false;
    }
    if(rotated)
    {
        keep(s[x].l,0);
        keep(s[x].r,1);
        keep(x,0);
        keep(x,1);
    }
}
// Insert one copy of key into the subtree rooted at x (x may be 0 = empty).
// Equal keys share a node and only bump its multiplicity; otherwise the key
// descends left/right and the subtree is rebalanced on the way back up.
void insert(int &x,int key)
{
    if(x==0)
    {
        x=++n;
        s[x].key=key;
        s[x].cnt=s[x].sc=s[x].size=1;
        return;
    }
    if(key==s[x].key)
    {
        ++s[x].cnt;
        ++s[x].sc;
        return;
    }
    const int dir=(key<s[x].key)?0:1;
    if(dir==0)
        insert(s[x].l,key);
    else
        insert(s[x].r,key);
    keep(x,dir);   // the side that grew is the side to check
}
// Remove one copy of key from the subtree rooted at x, rebalancing on the way
// back up. If key is not present, the tree is left unchanged. Previously this
// recursed forever on the null node.
void del(int &x,int key)
{
    if(!x)return;   // reached an empty slot: key not in the tree
    if(key<s[x].key)
    {
        del(s[x].l,key);
        keep(x,1);   // left side shrank, so the right side may now be too heavy
        return;
    }
    if(key>s[x].key)
    {
        del(s[x].r,key);
        keep(x,0);   // right side shrank, so the left side may now be too heavy
        return;
    }
    // key == s[x].key
    if(s[x].cnt>1)
    {
        s[x].cnt--;
        s[x].sc--;   // ancestors' sc is recomputed by their keep() calls
        return;
    }
    if(!s[x].l||!s[x].r)
    {
        x=s[x].l+s[x].r;   // zero or one child: splice it into x's place
        return;
    }
    // Two children: rotate the heavier child up, then delete from below it.
    if(s[s[x].l].size>s[s[x].r].size)
    {
        right_rotate(x);
        del(s[x].r,key);
        keep(x,0);
    }
    else
    {
        left_rotate(x);
        del(s[x].l,key);
        keep(x,1);
    }
}
// Return the k-th smallest key (1-based, counting multiplicity) in the subtree
// rooted at x. Assumes 1 <= k <= s[x].sc.
int find(int &x,int k)
{
    const int left=s[s[x].l].sc;
    if(k<=left)
        return find(s[x].l,k);
    if(k>left+s[x].cnt)
        return find(s[x].r,k-left-s[x].cnt);
    return s[x].key;   // left < k <= left + cnt: k falls on this node
}
// Return the 1-based rank of key within the subtree rooted at x. This is
// (number of stored elements, counted with multiplicity, that are < key) + 1.
// For a present key this is the rank of its first copy. For an absent key it
// is the rank it would get if inserted; previously this recursed forever.
// x is a reference only to keep the original signature; it is not modified.
int rank(int &x,int key)
{
    int smaller=0;
    int p=x;
    while(p)
    {
        if(key<s[p].key)
            p=s[p].l;
        else if(key==s[p].key)
        {
            smaller+=s[s[p].l].sc;
            break;
        }
        else
        {
            // Everything in the left subtree plus this node is below key.
            smaller+=s[s[p].l].sc+s[p].cnt;
            p=s[p].r;
        }
    }
    return smaller+1;
}
// Return the smallest key in the tree by following the left spine from the
// root. On an empty tree this yields s[0].key (0).
int getmin()
{
    int cur=root;
    while(s[cur].l!=0)
        cur=s[cur].l;
    return s[cur].key;
}
// Return the largest key in the tree by following the right spine from the
// root. On an empty tree this yields s[0].key (0).
int getmax()
{
    int cur=root;
    while(s[cur].r!=0)
        cur=s[cur].r;
    return s[cur].key;
}
// Return the index of the node holding the largest key strictly less than key
// in the subtree rooted at x. If there is none, return y (callers pass 0).
int pred(int &x,int y,int key)
{
    int best=y;
    int p=x;
    while(p!=0)
    {
        if(s[p].key<key)
        {
            best=p;        // candidate; a closer one can only be to the right
            p=s[p].r;
        }
        else
            p=s[p].l;
    }
    return best;
}
// Return the index of the node holding the smallest key strictly greater than
// key in the subtree rooted at x. If there is none, return y (callers pass 0).
int succ(int &x,int y,int key)
{
    int best=y;
    int p=x;
    while(p!=0)
    {
        if(s[p].key>key)
        {
            best=p;        // candidate; a closer one can only be to the left
            p=s[p].l;
        }
        else
            p=s[p].r;
    }
    return best;
}
// Read m operations "t x" and apply them to the SBT:
//   1 insert x, 2 delete one x, 3 print rank of x, 4 print x-th smallest,
//   5 print predecessor of x, 6 print successor of x.
// Input is now checked: a failed read stops processing instead of using
// uninitialized t / x.
int main()
{
    if(scanf("%d",&m)!=1)return 0;
    for(int i=1;i<=m;i++)
    {
        int t,x;
        if(scanf("%d%d",&t,&x)!=2)break;
        switch(t)
        {
            case 1:insert(root,x);break;
            case 2:del(root,x);break;
            case 3:printf("%d\n",rank(root,x));break;
            case 4:printf("%d\n",find(root,x));break;
            case 5:printf("%d\n",s[pred(root,0,x)].key);break;
            case 6:printf("%d\n",s[succ(root,0,x)].key);break;
        }
    }
    return 0;
}
SPT(Splay Tree):(整体思路是对的,rotate 和 splay 应该也都没有问题。)
#include <cstdio>
#include <algorithm>
#define N 201000
#define inf 0x3f3f3f3f
#define rt son[root][1]
#define lrt son[rt][0]
#define ls son[x][0]
#define rs son[x][1]
#define is(x) (x==son[fa[x]][1])
using namespace std;
struct SPT
{
int root,top;
int val[N],son[N][2],size[N],num[N],fa[N];
inline void link(int &x,int y,int d){son[y][d]=x;fa[x]=y;}
inline void pushup(int x)
{
size[x]=size[ls]+size[rs]+num[x];
}
inline void init()
{
top=0;
newnode(root,0,-inf);
newnode(son[root][1],root,inf);
size[1]=size[2]=num[1]=num[2]=0;
}
inline void rotate(int x)
{
int y=fa[x],z=fa[y],idx=is(x),idy=is(y);
link(son[x][!idx],y,idx);
link(y,x,!idx);
if(z)link(x,z,idy);fa[x]=z;
pushup(y);
pushup(x);
}
inline void splay(int x,int k=0)
{
int y,z;
while(fa[x]!=k)
{
y=fa[x],z=fa[y];
if(z==k){rotate(x);break;}
if(is(x)==is(y))rotate(y);
else rotate(x);
rotate(x);
}
if(!k)root=x;
}
inline int pred(int w,int k=0)
{
int x=root,y;
while(x)
{
if(w>val[x])y=x;
x=son[x][w>val[x]];
}
splay(y,k);
return val[y];
}
inline int succ(int w,int k=0)
{
int x=root,y;
while(x)
{
if(w<val[x])y=x;
x=son[x][w>=val[x]];
}
splay(y,k);
return val[y];
}
inline void newnode(int &x,int y,int w)
{
x=++top;
val[x]=w;
fa[x]=y;
size[x]=num[x]=1;
}
inline void insert(int w,int k=0)
{
int x=root;
while(son[x][w>val[x]])
{
if(w==val[x])
{
num[x]++;
splay(x,k);
return ;
}
x=son[x][w>val[x]];
}
newnode(son[x][w>val[x]],x,w);
splay(son[x][w>val[x]]);
}
inline bool remove(int w)
{
pred(w);succ(w,root);
if(!lrt)return 0;
if(num[lrt]-1)num[lrt]--;
else son[rt][0]=0;
pushup(rt);
pushup(root);
return 1;
}
inline int rank(int w,int k=0)
{
int x=root;
while(son[x][w>val[x]])
{
if(w==val[x])break;
x=son[x][w>val[x]];
}
splay(x,k);
return size[son[x][0]]+1;
}
inline int find(int rank,int k=0)
{
int x=root;
while(rank<=size[son[x][0]]||size[son[x][0]]+num[x]<rank)
{
if(rank<=size[son[x][0]])x=son[x][0];
else x=son[x][1],rank-=(size[son[x][0]]+num[x]);
}
splay(x,k);
return val[x];
}
inline void dfs(int x)
{
if(son[x][0])dfs(son[x][0]);
if(val[x]>0&&val[x]<inf)for(int i=1;i<=num[x];i++)printf("%d\n",val[x]);
if(son[x][1])dfs(son[x][1]);
}
}spt;
// Read n followed by n integers, insert them into the splay tree, and print
// them in ascending order. A failed read no longer leaves n or t
// uninitialized. The unused local k has been removed.
int main()
{
    // freopen("test.in","r",stdin);
    int n=0;
    if(scanf("%d",&n)!=1)return 0;
    spt.init();
    for(int i=1;i<=n;i++)
    {
        int t;
        if(scanf("%d",&t)!=1)break;
        spt.insert(t);
    }
    spt.dfs(spt.root);
    return 0;
}