写一个数据结构来提供以下操作:
- 插入 x x x
- 删除 x x x (多个相同则只删除一个)
- 查询 x x x 的排名(从 1 1 1 开始数)
- 查询排名为 x x x 的数
- 求 x x x 的前驱 (前驱定义为 max { k ∣ k < x } \max\{k|k<x\} max{k∣k<x})
- 求 x x x 的后继 (后继定义为 min { k ∣ k > x } \min\{k|k>x\} min{k∣k>x})
本质就是生成一个平衡树。
用 Treap 或者 Splay
- Treap
参考 oi-wiki 以及 https://www.luogu.com.cn/blog/Chanis/fhq-treap
这个词是 Tree 和 Heap 的组合,Treap 也正是由树和堆组合形成的数据结构,每个结点存储一个优先级 priority ,需要满足父结点的 priority 值要 ≥ \ge ≥ 两个儿子的 priority,该优先级值是在每个结点建立时随机生成的,因此 treap 是一种弱平衡二叉树(期望平衡)
实现分为旋转式和无旋式。
无旋式Treap:又称为分裂合并 treap,又又称 fhq treap,核心操作只有分裂和合并。
代码如下:
#include<iostream>
#include<cstdio>
#include<cstdlib>
#define MAXN 500005
using namespace std;
int sz[MAXN],ch[MAXN][2],dat[MAXN],val[MAXN];
int T,cnt,n,m,x,y,z,k,a,root;
int newnode(int x){
sz[++cnt]=1;val[cnt]=x;dat[cnt]=rand();return cnt;
}
void pushup(int x){sz[x]=1+sz[ch[x][0]]+sz[ch[x][1]];}
void split(int cur,int k,int &x,int &y){
if(!cur)x=y=0;
else{
if(val[cur]<=k)x=cur,split(ch[cur][1],k,ch[cur][1],y);
else y=cur,split(ch[cur][0],k,x,ch[cur][0]);
pushup(cur);
}
}
int merge(int x,int y){
if(!x||!y)return x+y;
if(dat[x]>dat[y]){ch[x][1]=merge(ch[x][1],y);pushup(x);return x;}
else {ch[y][0]=merge(x,ch[y][0]);pushup(y);return y;}
}
int kth(int cur,int k){
while(true){
if(k<=sz[ch[cur][0]])cur=ch[cur][0];
else if(k==sz[ch[cur][0]]+1)return cur;
else k-=sz[ch[cur][0]]+1,cur=ch[cur][1];
}
}
int main(){
#ifdef WINE
freopen("data.in","r",stdin);
#endif
scanf("%d",&T);
while(T--){
scanf("%d%d",&k,&a);
if(k==1){
split(root,a,x,y);
root=merge(merge(x,newnode(a)),y);
}else if(k==2){
split(root,a,x,z);
split(x,a-1,x,y);
y=merge(ch[y][0],ch[y][1]);
root=merge(merge(x,y),z);
}else if(k==3){
split(root,a-1,x,y);
printf("%d\n",sz[x]+1);
root=merge(x,y);
}else if(k==4)printf("%d\n",val[kth(root,a)]);
else if(k==5){
split(root,a-1,x,y);
printf("%d\n",val[kth(x,sz[x])]);
root=merge(x,y);
}else{
split(root,a,x,y);
printf("%d\n",val[kth(y,1)]);
root=merge(x,y);
}
}
return 0;
}
旋转式 Treap
代码如下:
#include<iostream>
#include<cstdio>
#include<cstdlib>
#define INF 0x3f3f3f3f
#define MAXN 1000020
using namespace std;
int n,ch[MAXN][2],val[MAXN],dat[MAXN],sz[MAXN],cnt[MAXN];
int tot,root,x,k;
int newnode(int v){
val[++tot]=v;dat[tot]=rand();
sz[tot]=1;cnt[tot]=1;
return tot;
}
void pushup(int i){
sz[i]=sz[ch[i][0]]+sz[ch[i][1]]+cnt[i];
}
void build(){
root=newnode(-INF);ch[root][1]=newnode(INF);
pushup(root);
}
void rotate(int &k,int d){
int tmp=ch[k][d^1];
ch[k][d^1]=ch[tmp][d];
ch[tmp][d]=k;
k=tmp;
pushup(ch[k][d]),pushup(k);
}
void insert(int &k,int v){
if(!k){k=newnode(v);return;}
if(v==val[k])cnt[k]++;
else{
int d=v<val[k]?0:1;
insert(ch[k][d],v);
if(dat[k]<dat[ch[k][d]])rotate(k,d^1);
}
pushup(k);
}
void remove(int &k,int v){
if(!k)return;
if(v==val[k]){
if(cnt[k]>1){cnt[k]--,pushup(k);return;}
if(ch[k][0]||ch[k][1]){
if(!ch[k][1]||dat[ch[k][0]]>dat[ch[k][1]])
rotate(k,1),remove(ch[k][1],v);
else rotate(k,0),remove(ch[k][0],v);
pushup(k);
}
else k=0;
return ;
}
v<val[k]?remove(ch[k][0],v):remove(ch[k][1],v);
pushup(k);
}
int getrank(int k,int v){
if(!k)return 0;
if(v==val[k])return sz[ch[k][0]]+1;
else if(v<val[k])return getrank(ch[k][0],v);
else return sz[ch[k][0]]+cnt[k]+getrank(ch[k][1],v);
}
int getval(int k,int rk){
if(!k)return INF;
if(rk<=sz[ch[k][0]])return getval(ch[k][0],rk);
else if(rk<=sz[ch[k][0]]+cnt[k])return val[k];
else return getval(ch[k][1],rk-sz[ch[k][0]]-cnt[k]);
}
int getpre(int v){
int k=root,pre;
while(k){
if(val[k]<v)pre=val[k],k=ch[k][1];
else k=ch[k][0];
}
return pre;
}
int getnxt(int v){
int k=root,nxt;
while(k){
if(val[k]>v)nxt=val[k],k=ch[k][0];
else k=ch[k][1];
}
return nxt;
}
int main(){
#ifdef WINE
freopen("data.in","r",stdin);
#endif
build();
scanf("%d",&n);
for(int i=1;i<=n;i++){
scanf("%d%d",&k,&x);
if(k==1)insert(root,x);
if(k==2)remove(root,x);
if(k==3)printf("%d\n",getrank(root,x)-1);
if(k==4)printf("%d\n",getval(root,x+1));
if(k==5)printf("%d\n",getpre(x));
if(k==6)printf("%d\n",getnxt(x));
}
return 0;
}
- Splay
Splay Tree(伸展树)是一种二叉排序树,能够在 O ( log n ) O(\log n) O(logn) 时间内完成插入、查找和删除。
参考 https://blog.csdn.net/qq_33184171/article/details/70304674
代码如下:
#include<iostream>
#include<cstdio>
#define MAXN 200010
#define INF 0x3f3f3f3f
using namespace std;
int tot,root,k,x,ch[MAXN][2],f[MAXN],sz[MAXN],val[MAXN];
int cnt[MAXN],T;
void newnode(int k,int v,int fa){
f[k]=fa;val[k]=v;sz[k]=cnt[k]=1;
ch[k][0]=ch[k][1]=0;
}
void pushup(int x){
if(x)sz[x]=sz[ch[x][0]]+sz[ch[x][1]]+cnt[x];
}
int search(int k,int x){
if(ch[k][0]&&val[k]>x)return search(ch[k][0],x);
if(ch[k][1]&&val[k]<x)return search(ch[k][1],x);
return k;
}
void rotate(int x,int k){// 0:left, 1:right
int y=f[x],z=f[y];
ch[y][k^1]=ch[x][k];if(ch[x][k])f[ch[x][k]]=y;
f[x]=z;if(z)ch[z][ch[z][1]==y]=x;
f[y]=x;ch[x][k]=y;
pushup(y);pushup(x);
}
void splay(int x,int t){
for(int i=f[x];f[x]!=t;i=f[x])
rotate(x,(ch[i][0]==x));
if(t==0)root=x;
}
int extreme(int x,int k){// k=0,1: min,max
while(ch[x][k])x=ch[x][k];splay(x,0);
return x;
}
int pre(int x){
int k=search(root,x);splay(k,0);
if(val[k]<x)return k;
return extreme(ch[k][0],1);
}
int nxt(int x){
int k=search(root,x);splay(k,0);
if(val[k]>x)return k;
return extreme(ch[k][1],0);
}
void insert(int x){
int y=search(root,x),k=-1;
if(val[y]==x){
cnt[y]++;sz[y]++;
for(int i=y;i;i=f[i])pushup(i);
}else{
int p=pre(x),s=nxt(x);
splay(p,0);splay(s,p);
newnode(++tot,x,ch[root][1]);
ch[ch[root][1]][0]=tot;
for(int i=ch[root][1];i;i=f[i])pushup(i);
}
if(k==-1)splay(y,0);else splay(tot,0);
}
void remove(int x){
int y=search(root,x);
if(val[y]!=x)return;
if(cnt[y]>1){
cnt[y]--;sz[k]--;
for(int i=y;i;i=f[i])pushup(i);
}else if(!ch[y][0]||!ch[y][1]){
int z=f[y];
ch[z][ch[z][1]==y]=ch[y][ch[y][0]==0];
f[ch[y][ch[y][0]==0]]=z;
for(int i=z;i;i=f[i])pushup(i);
}else{
int p=pre(x),s=nxt(x);
splay(p,0);splay(s,p);
ch[ch[root][1]][0]=0;
for(int i=s;i;i=f[i])pushup(i);
}
}
int rk(int x){
int k=search(root,x);splay(k,0);
return sz[ch[root][0]]+1;
}
int kth(int x,int k){
if(sz[ch[x][0]]+1<=k&&k<=sz[ch[x][0]]+cnt[x])return x;
if(sz[ch[x][0]]>=k)return kth(ch[x][0],k);
return kth(ch[x][1],k-sz[ch[x][0]]-cnt[x]);
}
int main(){
#ifdef WINE
freopen("data.in","r",stdin);
#endif
root=1;
newnode(++tot,-INF,0);newnode(++tot,INF,root);
ch[root][1]=tot;
scanf("%d",&T);
while(T--){
scanf("%d%d",&k,&x);
if(k==1)insert(x);
if(k==2)remove(x);
if(k==3)printf("%d\n",rk(x)-1);
if(k==4)printf("%d\n",val[kth(root,x+1)]);
if(k==5)printf("%d\n",val[pre(x)]);
if(k==6)printf("%d\n",val[nxt(x)]);
}
return 0;
}