平衡树模板题
http://codevs.cn/problem/4543/
Description
写一种数据结构,来维护一些数,其中需要提供以下操作:
1. 插入x数
2. 删除x数(若有多个相同的数,因只删除一个)
3. 查询x数的排名(若有多个相同的数,因输出最小的排名)
4. 查询排名为x的数
5. 求x的前驱(前驱定义为小于x,且最大的数)
6. 求x的后继(后继定义为大于x,且最小的数)
Input Description
第一行为n,表示操作的个数,下面n行每行有两个数opt和x,opt表示操作的序号(1<=opt<=6)
Output Description
对于操作3,4,5,6每行输出一个数,表示对应答案
Sample Input
10
1 10
1 10
1 10
1 10
1 10
1 10
1 10
1 10
1 10
1 10
Sample Output
EOF(无输出)
Hint
n <= 100000;
1 <= opt <= 6;
abs(x) <= 2000000000;
#include <cstdio>
#include <cstdlib>
#define N 100005
using namespace std;
int l[N],r[N],tr[N],siz[N],num[N],rnd[N];
int n,u,k,m,opt,ans,root;
void update(int x){siz[x] = siz[l[x]] + siz[r[x]] + num[x];}
void rturn(int &x){
u = l[x],l[x] = r[u],r[u] = x;
siz[u] = siz[x],update(x),x = u;
}
void lturn(int &x){
u = r[x],r[x] = l[u],l[u] = x;
siz[u] = siz[x],update(x),x = u;
}
void ins(int &x){
if (!x) tr[x = ++m] = k,rnd[x] = rand();
siz[x]++;if (k == tr[x]) num[x]++;
if (k > tr[x]) {ins(r[x]);if (rnd[r[x]] > rnd[x]) lturn(x);return;}
if (k < tr[x]) {ins(l[x]);if (rnd[l[x]] > rnd[x]) rturn(x);return;}
}
void del(int &x){
if (k==tr[x]) {
if (num[x]>1) {num[x]--;return;}
if ((!l[x]) || (!r[x])) {x = l[x] + r[x];return;}
if (rnd[r[x]] > rnd[l[x]]) rturn(x);else lturn(x);
del(x);return;
}
siz[x]--;(k > tr[x]) ? del(r[x]) : del(l[x]);
}
int get_rnk(int x){
if (k == tr[x]) return siz[l[x]]+1;
if (k < tr[x]) return get_rnk(l[x]);
if (k > tr[x]) return siz[l[x]] + num[x] + get_rnk(r[x]);
}
int get_val(int x){
if (k <= siz[l[x]]) return get_val(l[x]);
k -= siz[l[x]];if (k <= num[x]) return tr[x];
k -= num[x];return get_val(r[x]);
}
int prev(int x){
if (!x) return ans;
if (k <= tr[x]) return prev(l[x]);
ans = tr[x];return prev(r[x]);
}
int next(int x){
if (!x) return ans;
if (k >= tr[x]) return next(r[x]);
ans = tr[x];return next(l[x]);
}
int main(){
scanf("%d",&n);
for (int i=1;i<=n;i++){
scanf("%d%d",&opt,&k);
if (opt==1) ins(root);
if (opt==2) del(root);
if (opt==3) printf("%d\n",get_rnk(root));
if (opt==4) printf("%d\n",get_val(root));
if (opt==5) printf("%d\n",prev(root));
if (opt==6) printf("%d\n",next(root));
}
return 0;
}