#include<cstdio>
#define N 100010
#define which(x) (ls[f[(x)]]==(x))
using namespace std;
int n,m,ls[N],rs[N],val[N],sze[N],cnt[N],f[N],root,idx;
int read()
{
int ans=0,fu=1;
char j=getchar();
for (;j<'0' || j>'9';j=getchar()) if (j=='-') fu=-1;
for (;j>='0' && j<='9';j=getchar()) ans*=10,ans+=j-'0';
return ans*fu;
}
void updt(int x)
{
sze[x]=sze[ls[x]]+sze[rs[x]]+1;
}
void Rotate(int u)
{
int v=f[u],w=f[v],b=which(u)?rs[u]:ls[u];
if (w) which(v)?rs[w]=u:ls[w]=u;
which(u)?(ls[v]=b,rs[u]=v):(rs[v]=b,ls[u]=v);
f[u]=w,f[v]=u;
if (b) f[b]=v;
updt(v),updt(u);
}
void Splay(int x,int tar)
{
while (f[x]!=tar)
{
if (f[f[x]]!=tar)
{
if (which(f[x])==which(x)) Rotate(f[x]);
else Rotate(x);
}
Rotate(x);
}
if (!tar) root=x;
}
int find(int x)
{
int u=root,v=0;
while (u && val[u]!=x)
{
v=u;
if (x<val[u]) u=ls[u];
else u=rs[u];
}
return u?u:v;
}
void insert(int x)
{
int u=root,v=0;
while (u && val[u]!=x)
{
v=u;
if (x<=val[u]) u=ls[u];
else u=rs[u];
}
if (u && val[u]==x)
return (void)(cnt[u]++,sze[u]++,Splay(u,0));
f[++idx]=v;
sze[idx]=1;
val[idx]=x;
if (v) x<=val[v]?ls[v]=idx:rs[v]=idx;
Splay(idx,0);
}
int getmn(int x)
{
while (ls[x]) x=ls[x];
return x;
}
int getmx(int x)
{
while (rs[x]) x=rs[x];
return x;
}
void erase(int x)
{
int tmp=find(x);
Splay(tmp,0);
if (cnt[tmp]>1) cnt[tmp]--,sze[tmp]--;
else if (!ls[tmp] || !rs[tmp]) root=ls[tmp]+rs[tmp],f[root]=0;
else
{
f[ls[tmp]]=0;
int u=getmx(ls[tmp]);
Splay(u,0);
rs[u]=rs[tmp];
f[rs[tmp]]=u;
updt(u);
}
}
int getkth(int k)
{
int cur=root;
while (cur)
{
if (sze[ls[cur]]>=k) cur=ls[cur];
else if (sze[ls[cur]]+cnt[cur]>=k) return val[cur];
else k-=sze[ls[cur]]+cnt[cur],cur=rs[cur];
}
return val[cur];
}
int getrank(int x)
{
int cur=find(x);
Splay(cur,0);
return sze[ls[cur]]+1;
}
int getpre(int x)
{
int cur=find(x);
if (val[cur]<x) return val[cur];
Splay(cur,0);
return val[getmx(ls[cur])];
}
int getnxt(int x)
{
int cur=find(x);
if (val[cur]>x) return val[cur];
Splay(cur,0);
return val[getmn(rs[cur])];
}
int main()
{
n=read();
for (int i=1,op,a;i<=n;i++)
{
op=read();
a=read();
printf("%d %d\n",op,a);
if (op==1) insert(a);
if (op==2) erase(a);
if (op==3) printf("%d\n",getrank(a));
if (op==4) printf("%d\n",getkth(a));
if (op==5) printf("%d\n",getpre(a));
if (op==6) printf("%d\n",getnxt(a));
}
return 0;
}
转载于:https://www.cnblogs.com/mrha/p/8157645.html