rotate和splay跳过
insert就是insert,不要忘了splay。
find写成找第一个>=v的。
前驱后继就是find后先往左右跳一步然后能往右左走就走
del就是find后把前驱后继splay的根和根的右儿子删除即可。
getk就是find后splay到根然后左子树大小+1
getkth就是在上面二分即可。
为了保证存在前驱后继可以先insert(±INF)
代码:
#include<iostream>
#include<cstring>
#include<cstdio>
#include<climits>
#include<algorithm>
#define N 100010
#define inf INT_MIN
#define INF INT_MAX
using namespace std;
struct Splay{
int cnt[N],val[N],sz[N],ch[N][2],fa[N];
int node_cnt,root,inf_pos,INF_pos;
Splay()
{
node_cnt=0;root=0;
fa[0]=ch[0][0]=ch[0][1]=sz[0]=val[0]=cnt[0]=0;
}
inline int push_up(int x)
{
return sz[x]=sz[ch[x][0]]+sz[ch[x][1]]+cnt[x];
}
inline int new_node(int _v=0)
{
int x=++node_cnt;
cnt[x]=sz[x]=1;val[x]=_v;
ch[x][0]=ch[x][1]=fa[x]=0;
return x;
}
inline int getwh(int x)
{
return ch[fa[x]][1]==x;
}
inline int set_child(int x,int y,int wh)
{
if(!x) return fa[y]=0;
ch[x][wh]=y;if(y) fa[y]=x;
push_up(x);return 0;
}
inline int rotate(int x)
{
int y=fa[x],z=fa[y],a=getwh(x),b=getwh(y);
set_child(y,ch[x][a^1],a);
set_child(x,y,a^1);
if(z) set_child(z,x,b);
else root=x,fa[x]=0;return x;
}
inline int splay(int x,int tar=0)
{
if(!x) return 0;
while(fa[x]!=tar)
{
if(fa[fa[x]]!=tar)
if(getwh(x)^getwh(fa[x])) rotate(x);
else rotate(fa[x]);
rotate(x);
}
if(!tar) root=x;return x;
}
inline int insert(int v)
{
int now=root,last=0,x;
while(now)
{
last=now;
if(v<val[now]) now=ch[now][0];
else if(v>val[now]) now=ch[now][1];
else break;
}
if(now) cnt[now]++,sz[now]++,splay(x=now);
else x=new_node(v),set_child(last,x,val[x]>val[last]),splay(x);
return x;
}
inline int find(int v)
{
int now=root,x=INF_pos;
while(now)
if(val[now]>=v) now=ch[x=now][0];
else now=ch[now][1];
splay(x);return x;
}
inline int pre(int v)
{
int x=find(v);x=ch[x][0];
while(ch[x][1]) x=ch[x][1];
return x;
}
inline int post(int v)
{
int x=find(v);
if(v<val[x]) return x;
x=ch[x][1];
while(ch[x][0]) x=ch[x][0];
return x;
}
inline int del_node(int x)
{
cnt[x]--;sz[x]--;return 0;
}
inline int delete_splay(int _v)
{
int x=find(_v),y=pre(val[x]),z=post(val[x]);
if(val[x]!=_v) return 0;
del_node(x);if(cnt[x]) return 0;
splay(y),splay(z,y),set_child(z,0,0),push_up(y);
return 0;
}
inline int get_rank(int v)
{
int x=find(v);return sz[ch[x][0]]+(val[x]==v);
}
inline int get_kth(int k)
{
int now=root;
while(now)
if(k<=sz[ch[now][0]]) now=ch[now][0];
else if(k>sz[ch[now][0]]+cnt[now]) k-=sz[ch[now][0]]+cnt[now],now=ch[now][1];
else break;
splay(now);return now;
}
inline int init_splay()
{
inf_pos=insert(inf);sz[inf_pos]=cnt[inf_pos]=0;
INF_pos=insert(INF);sz[INF_pos]=cnt[INF_pos]=0;
return 0;
}
}spl;
inline int inn()
{
char ch;int x=0;bool f=false;
while(((ch=getchar())^'-')&&(ch<'0'||ch>'9'));
if(ch^'-') x=ch^'0';else f=true;
while((ch=getchar())>='0'&&ch<='9')
x=(x<<1)+(x<<3)+(ch^'0');
return f?-x:x;
}
int main()
{
spl.init_splay();
int q;scanf("%d",&q);
while(q--)
{
int opt;scanf("%d",&opt);
switch(opt)
{
case 1:spl.insert(inn());break;
case 2:spl.delete_splay(inn());break;
case 3:printf("%d\n",spl.get_rank(inn()));break;
case 4:printf("%d\n",spl.val[spl.get_kth(inn())]);break;
case 5:printf("%d\n",spl.val[spl.pre(inn())]);break;
case 6:printf("%d\n",spl.val[spl.post(inn())]);break;
}
}
return 0;
}