P3369普通平衡树:
此题使用Spaly会T,因为Splay要求严格平衡,即每个节点的左右子树之间高度相差不超过1,这个要求可以避免二叉树在查询时出现O(n)的情况,但也会因为过多的增加和删除操作导致复杂度上升。
(原本是会T的,但是改了一些细节后居然A了)
代码(修改后):
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
using namespace std;
const int N = 1e6+7;
inline int read(){
int ref=0,x=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')x=-1;ch=getchar();}
while(isdigit(ch)) ref=ref*10+ch-'0',ch=getchar();
return ref*x;
}
int n,tot,rot,cnt[N],ch[N][2],fa[N],val[N],siz[N];
void rotate(int x){
int y=fa[x],z=fa[y];
int k=(ch[y][1]==x);
ch[z][ch[z][1]==y]=x;
fa[x]=z;
ch[y][k]=ch[x][k^1];
fa[ch[x][k^1]]=y;
ch[x][k^1]=y;
fa[y]=x;
siz[y]=siz[ch[y][0]]+siz[ch[y][1]]+cnt[y];
siz[x]=siz[ch[x][0]]+siz[ch[x][1]]+cnt[x];
}
void splay(int x,int to){
while(fa[x]!=to){
int y=fa[x],z=fa[y];
if(z!=to) ((ch[y][1]==x)==(ch[z][1]==y))?rotate(y):rotate(x);
//唯独在x,父亲和祖父三者在一条直线上时才用双旋,否则一律都是两次单旋
rotate(x);
}
if(to==0) rot=x;
}
void find(int x){//找到离x值最近的一个数(包括自己本身)
int u=rot;
if(!u) return ;
while(ch[u][x>val[u]]&&val[u]!=x) u=ch[u][x>val[u]];
splay(u,0);
}
void ins(int x){
int ff=0,u=rot;
while(val[u]!=x&&u) ff=u,u=ch[u][x>val[u]];
if(u) cnt[u]++;
else{
u=++tot;
if(ff==0) rot=u;
else ch[ff][x>val[ff]]=u;
//siz[ff]++;
//本处不用写siz[ff]++,因为这似乎会引发错误(原因未知)
//而splay操作会自底向上地更新siz
fa[u]=ff;
val[u]=x;
cnt[u]=siz[u]=1;
}
splay(u,0);
}
int findpre(int x){
find(x);
if(val[rot]<x) return rot;//特判 x不在树上的情况
int u=ch[rot][0];
if(!u) return -1;
while(ch[u][1]) u=ch[u][1];
return u;
}
int findnxt(int x){
find(x);
if(val[rot]>x) return rot;
int u=ch[rot][1];
if(!u) return -1;
while(ch[u][0]) u=ch[u][0];
return u;
}
void del(int x){
int pre=findpre(x),nxt=findnxt(x);
splay(pre,0),splay(nxt,pre);
int u=ch[nxt][0];
if(cnt[u]>1) cnt[u]--,splay(u,0);
else ch[nxt][0]=0;
}
int findk(int x){
int u=rot;
if(siz[u]<x) return -1;
while(1){
if(x>siz[ch[u][0]]&&x<=siz[ch[u][0]]+cnt[u]) return u;
else if(x<=siz[ch[u][0]]) u=ch[u][0];
else x-=(siz[ch[u][0]]+cnt[u]),u=ch[u][1];
}
}
int main()
{
ins(-2147483640),ins(2147483640);//避免出现最小数没有前驱,最大数没有后继
cin>>n;
for(int i=1;i<=n;++i){
int op=read(),x=read();
if(op==1) ins(x);
if(op==2) del(x);
if(op==3) find(x),printf("%d\n",siz[ch[rot][0]]);
if(op==4) printf("%d\n",val[findk(x+1)]);
if(op==5) printf("%d\n",val[findpre(x)]);
if(op==6) printf("%d\n",val[findnxt(x)]);
}
return 0;
}