板子题哇。
splay
#include <bits/stdc++.h>
using namespace std;
// Shorthand for `long long`; no use of it appears in the code shown here.
#define ll long long
// Large int constant (0x3f3f3f3f); no use of it appears in the code shown here.
#define inf 0x3f3f3f3f
// Length of each per-node array declared below (valid indices 0..N-1).
#define N 100010
/**
 * Reads the next decimal integer from stdin.
 *
 * Every character that is not a digit is skipped first; a '-' seen while
 * skipping makes the result negative.
 *
 * @return the parsed value, or 0 when end of input is reached before any digit.
 */
inline int read(){
    int value=0,sign=1;
    // getchar() returns an int; keeping it as int lets EOF be told apart from
    // every real character, so the skip loop below cannot spin forever.
    int ch=getchar();
    while(ch!=EOF&&(ch<'0'||ch>'9')){
        if(ch=='-') sign=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9'){
        value=value*10+(ch-'0');
        ch=getchar();
    }
    return value*sign;
}
// Splay-tree state, indexed by node id; id 0 stands for "no node".
// sz[p]: element count of p's subtree (copies included); fa[p]: parent of p;
// c[p][0] / c[p][1]: left / right child of p; num[p]: copies of val[p] stored at p; val[p]: key of p.
// tot: number of nodes allocated so far; rt: current root; ans: result slot written by pre()/succ().
int sz[N],fa[N],c[N][2],num[N],val[N],tot=0,rt=0,ans=0;
// Recomputes sz[p] from p's own copy count and the sizes of its two child subtrees.
inline void update(int p){
    sz[p]=num[p]+sz[c[p][0]]+sz[c[p][1]];
}
/**
 * Looks up the node holding key x in the subtree rooted at p.
 *
 * @param p  root of the subtree to search (0 means an empty subtree)
 * @param x  key to look for
 * @return   id of the node whose val equals x, or 0 when no such node exists
 */
inline int find(int p,int x){
    // Iterative descent: stopping at the null node (0) keeps a missing key
    // from recursing without end.
    while(p&&val[p]!=x) p=c[p][x>val[p]];
    return p;
}
/**
 * Lifts x one level, making its former parent its child.
 *
 * @param x  node to lift
 * @param k  slot naming the top of the region being splayed; when x's parent
 *           is that top node, the slot itself is rewritten to x
 */
inline void rotate(int x,int &k){
    int parent=fa[x];
    int grand=fa[parent];
    int side=(c[parent][1]==x)?1:0;   // which child of parent x currently is
    int other=side^1;

    // Hook x in where parent used to hang.
    if(parent==k){
        k=x;
    }else{
        c[grand][c[grand][1]==parent]=x;
    }

    int inner=c[x][other];            // x's subtree that moves over to parent
    fa[inner]=parent;
    fa[parent]=x;
    fa[x]=grand;
    c[parent][side]=inner;
    c[x][other]=parent;

    // parent now sits below x, so its size is refreshed first.
    update(parent);
    update(x);
}
/**
 * Rotates x upward until it occupies the slot k.
 *
 * @param x  node to move up
 * @param k  slot to end up in (e.g. rt, or a child link of another node)
 */
inline void splay(int x,int &k){
    while(x!=k){
        int parent=fa[x];
        if(parent!=k){
            int grand=fa[parent];
            bool xIsRight=(c[parent][1]==x);
            bool parentIsRight=(c[grand][1]==parent);
            // Different sides: rotate x first; same side: rotate the parent first.
            rotate(xIsRight!=parentIsRight?x:parent,k);
        }
        rotate(x,k);
    }
}
/**
 * Inserts key x into the subtree hanging from slot p, then splays the touched node to rt.
 *
 * @param p  link (rt or a child slot) through which the subtree is reached
 * @param x  key to insert
 * @param f  parent of the node in slot p (0 for the root)
 */
inline void ins(int &p,int x,int f){
    if(p==0){
        // Empty slot: allocate a fresh node holding a single copy of x.
        p=++tot;
        num[p]=1;
        sz[p]=1;
        fa[p]=f;
        val[p]=x;
        splay(p,rt);
        return;
    }
    if(val[p]==x){
        // Key already stored: only its copy count grows.
        num[p]++;
        splay(p,rt);
        return;
    }
    int dir=(x<val[p])?0:1;
    ins(c[p][dir],x,p);
    update(p);
}
inline int getrk(int p,int x){//查排名
if(val[p]==x){splay(p,rt);return sz[c[p][0]]+1;}
if(x<val[p]) return getrk(c[p][0],x);
else return getrk(c[p][1],x);
}
/**
 * Selection query: value of the x-th smallest stored element (1-based, copies counted).
 *
 * The node that holds the answer is splayed to rt.
 *
 * @param p  root of the subtree to search (callers pass rt)
 * @param x  1-based position
 * @return   the x-th smallest value, or 0 when x lies outside the stored range
 */
inline int getkth(int p,int x){
    // The loop stops at the null node, so an out-of-range position ends the
    // search instead of recursing without end.
    while(p){
        int leftSize=sz[c[p][0]];
        if(x<=leftSize){
            p=c[p][0];
            continue;
        }
        x-=leftSize;
        if(x<=num[p]){
            splay(p,rt);
            return val[p];
        }
        x-=num[p];
        p=c[p][1];
    }
    return 0;
}
// Predecessor search: stores into the global ans the largest key below x found
// on the path from p; ans is left untouched when no such key is met.
inline void pre(int p,int x){
    while(p){
        if(val[p]<x){
            ans=val[p];      // candidate; a larger one may exist to the right
            p=c[p][1];
        }else{
            p=c[p][0];
        }
    }
}
// Successor search: stores into the global ans the smallest key above x found
// on the path from p; ans is left untouched when no such key is met.
inline void succ(int p,int x){
    while(p){
        if(val[p]>x){
            ans=val[p];      // candidate; a smaller one may exist to the left
            p=c[p][0];
        }else{
            p=c[p][1];
        }
    }
}
/**
 * Removes one copy of key xx from the tree; does nothing when xx is not stored.
 *
 * Root handling matters here: the root link and the new root's parent pointer
 * must both be kept correct, otherwise later operations walk through a removed
 * node (the original note warned this leads to time-outs).
 *
 * @param xx  key to remove
 */
inline void del(int xx){
    // Locate the node; a missing key ends at the null node 0.
    int x=rt;
    while(x&&val[x]!=xx) x=c[x][xx>val[x]];
    if(!x) return;

    if(--num[x]){
        // Copies remain: refresh the size here (splay is a no-op when x is
        // already the root) and bring the node to the top.
        update(x);
        splay(x,rt);
        return;
    }

    if(!c[x][0]||!c[x][1]){
        // At most one child: that child (or 0) takes x's place.
        int child=c[x][0]+c[x][1];
        if(x==rt){
            rt=child;
            fa[child]=0;     // the new root must not keep pointing at the removed node
            return;
        }
        int y=fa[x];
        c[y][c[y][1]==x]=child;
        if(child) fa[child]=y;
        update(y);
        splay(y,rt);
        return;
    }

    // Two children: bring the in-order predecessor to the root and the
    // successor right below it; x is then the successor's only left
    // descendant and can be cut off.
    splay(x,rt);
    int prex=c[x][0];
    while(c[prex][1]) prex=c[prex][1];
    int succx=c[x][1];
    while(c[succx][0]) succx=c[succx][0];
    splay(prex,rt);
    splay(succx,c[prex][1]);
    c[succx][0]=0;
    update(succx);
    update(prex);
}
/**
 * Driver: reads an operation count, then that many "op x" pairs.
 *   1 insert x, 2 delete x, 3 print rank of x, 4 print x-th smallest,
 *   5 print predecessor of x, 6 print successor of x.
 * Any other op code is ignored.
 */
int main(){
    // freopen("a.in","r",stdin);
    int owo=read();
    while(owo--){
        int op=read();
        int x=read();
        if(op==1){
            ins(rt,x,0);
        }else if(op==2){
            del(x);
        }else if(op==3){
            printf("%d\n",getrk(rt,x));
        }else if(op==4){
            printf("%d\n",getkth(rt,x));
        }else if(op==5){
            // pre() leaves ans untouched when nothing is smaller than x, so
            // clear it first instead of printing a stale earlier answer.
            ans=0;
            pre(rt,x);
            printf("%d\n",ans);
        }else if(op==6){
            // Same reset for the successor query.
            ans=0;
            succ(rt,x);
            printf("%d\n",ans);
        }
    }
    return 0;
}
treap
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<ctime>
using namespace std;
// Length of the node pool `tree` declared below.
int const N=100010;
// n: operations still to process; num: number of nodes allocated so far;
// root: index of the treap root (0 = empty tree); ans: result slot written by pread1()/succ1().
int n,num=0,root=0,ans=0;
// One treap node. Nodes live in the pool `tree` and refer to each other by
// index; index 0 stands for "no node".
struct node{
    int x;    // key stored in this node
    int l;    // index of the left child
    int r;    // index of the right child
    int s1;   // number of copies of x held by this node
    int s2;   // element count of the whole subtree, copies included
    int w;    // random priority; a child with a smaller w is rotated above its parent
}tree[N];
// Recomputes tree[p].s2 from the node's own copy count and the sizes of the children it has.
void update(int p){
    int total=tree[p].s1;
    const int left=tree[p].l;
    const int right=tree[p].r;
    if(left!=0) total+=tree[left].s2;
    if(right!=0) total+=tree[right].s2;
    tree[p].s2=total;
}
// Left rotation: the right child of the node in slot p takes over the slot and
// the old top node becomes its left child.
void leftrotate(int &p){
    const int down=p;               // node that moves down
    const int up=tree[down].r;      // node that moves up
    tree[down].r=tree[up].l;
    tree[up].l=down;
    // The subtree as a whole keeps its element count; only `down` needs recomputing.
    tree[up].s2=tree[down].s2;
    update(down);
    p=up;
}
// Right rotation: the left child of the node in slot p takes over the slot and
// the old top node becomes its right child.
void rightrotate(int &p){
    const int down=p;               // node that moves down
    const int up=tree[down].l;      // node that moves up
    tree[down].l=tree[up].r;
    tree[up].r=down;
    // The subtree as a whole keeps its element count; only `down` needs recomputing.
    tree[up].s2=tree[down].s2;
    update(down);
    p=up;
}
void insert1(int &p,int x){
if(p==0){
tree[++num].x=x;
tree[num].s1=1;
tree[num].s2=1;
tree[num].w=rand();
tree[num].l=0;
tree[num].r=0;
p=num;
return;
}
tree[p].s2++;
if(tree[p].x==x) tree[p].s1++;
else if(x<tree[p].x){
insert1(tree[p].l,x);
if(tree[tree[p].l].w<tree[p].w) rightrotate(p);
}
else{
insert1(tree[p].r,x);
if(tree[tree[p].r].w<tree[p].w) leftrotate(p);
}
}
// Worker for delete1: removes one copy of x below slot p and reports whether
// anything was removed, so sizes are only adjusted for a real removal.
static bool delete1_impl(int &p,int x){
    if(p==0) return false;
    if(tree[p].x==x){
        if(tree[p].s1>1){
            // More than one copy: drop one and keep the node.
            tree[p].s1--;
            tree[p].s2--;
            return true;
        }
        if(tree[p].l==0||tree[p].r==0){
            // At most one child: it (or 0) takes over the slot.
            p=tree[p].l+tree[p].r;
            return true;
        }
        // Two children: rotate the child with the smaller w above the doomed
        // node. Sizes are not touched here; the descent below accounts for
        // the removal once it has happened.
        if(tree[tree[p].l].w<tree[tree[p].r].w) rightrotate(p);
        else leftrotate(p);
        // Slot p now names the promoted child and x sits one level further
        // down, so fall through to the ordinary descent.
    }
    int &next=(x<tree[p].x)?tree[p].l:tree[p].r;
    if(!delete1_impl(next,x)) return false;
    tree[p].s2--;
    return true;
}

/**
 * Removes one copy of x from the subtree hanging from slot p.
 * The tree is left unchanged when x is not stored.
 *
 * @param p  link (root or a child field) through which the subtree is reached
 * @param x  key to remove
 */
void delete1(int &p,int x){
    delete1_impl(p,x);
}
/**
 * Rank query: number of stored elements smaller than x, plus one.
 * The same rule is applied whether or not x itself is stored.
 *
 * @param p  root of the subtree to search
 * @param x  key whose rank is wanted
 * @return   (#elements < x in the subtree) + 1
 */
int query1(int p,int x){
    int smaller=0;
    while(p!=0){
        const int leftSize=tree[tree[p].l].s2;
        if(tree[p].x==x) return smaller+leftSize+1;
        if(tree[p].x<x){
            // This node and its whole left subtree are smaller than x.
            smaller+=leftSize+tree[p].s1;
            p=tree[p].r;
        }else{
            p=tree[p].l;
        }
    }
    return smaller+1;
}
/**
 * Selection query: value of the x-th smallest element (1-based, copies counted).
 *
 * @param p  root of the subtree to search
 * @param x  1-based position inside that subtree
 * @return   the value at that position, or 0 when the search runs off the tree
 */
int query2(int p,int x){
    while(p!=0){
        const int before=tree[tree[p].l].s2;     // elements left of this node
        const int upto=before+tree[p].s1;        // ... plus this node's copies
        if(x>upto){
            x-=upto;
            p=tree[p].r;
        }else if(x>before){
            return tree[p].x;
        }else{
            p=tree[p].l;
        }
    }
    return 0;
}
// Predecessor search: stores into the global ans the largest key below x found
// on the path from p; ans is left untouched when no such key is met.
void pread1(int p,int x){
    while(p!=0){
        if(tree[p].x<x){
            ans=tree[p].x;      // candidate; a larger one may exist to the right
            p=tree[p].r;
        }else{
            p=tree[p].l;
        }
    }
}
// Successor search: stores into the global ans the smallest key above x found
// on the path from p; ans is left untouched when no such key is met.
void succ1(int p,int x){
    while(p!=0){
        if(tree[p].x>x){
            ans=tree[p].x;      // candidate; a smaller one may exist to the left
            p=tree[p].l;
        }else{
            p=tree[p].r;
        }
    }
}
/**
 * Driver: reads an operation count, then that many "opt x" pairs.
 *   1 insert x, 2 delete x, 3 print rank of x, 4 print x-th smallest,
 *   5 print predecessor of x, 6 print successor of x.
 * Input comes from "input4.in" and output goes to "ans.out".
 *
 * @return 0 on success, 1 when either file cannot be opened.
 */
int main(){
    srand(time(0));
    // A failed freopen leaves the stream unusable, so stop instead of reading from it.
    if(freopen("input4.in","r",stdin)==NULL){
        perror("input4.in");
        return 1;
    }
    if(freopen("ans.out","w",stdout)==NULL){
        perror("ans.out");
        return 1;
    }
    if(scanf("%d",&n)!=1) return 0;
    while(n--){
        int x,opt;
        // Stop on truncated or malformed input rather than acting on unset values.
        if(scanf("%d%d",&opt,&x)!=2) break;
        switch(opt){
            case 1:
                insert1(root,x);
                break;
            case 2:
                delete1(root,x);
                break;
            case 3:
                printf("%d\n",query1(root,x));
                break;
            case 4:
                printf("%d\n",query2(root,x));
                break;
            case 5:
                // The searches only overwrite ans when a candidate exists.
                ans=0;
                pread1(root,x);
                printf("%d\n",ans);
                break;
            case 6:
                ans=0;
                succ1(root,x);
                printf("%d\n",ans);
                break;
            default:
                break;
        }
    }
    return 0;
}