#include <cstdio>
#include <cstring>
#include <cstdlib>
using namespace std;
// Node of the pointer-based treap: a BST on key v combined with a heap on the
// random priority r. s caches the number of nodes in the subtree rooted here.
struct Node
{
    Node *ch[2];   // ch[0] = left child, ch[1] = right child (NULL if absent)
    int r;         // random priority drawn with rand() at construction
    int v;         // key
    int s;         // subtree size (this node plus both children's subtrees)

    // Creates a leaf holding key v with a fresh random priority.
    Node(int v): r(rand()), v(v), s(1) {
        ch[0] = NULL;
        ch[1] = NULL;
    }

    // Orders nodes by priority only.
    bool operator < (const Node& rhs) const{
        return rhs.r > r;
    }

    // Direction to search for x from this node:
    // -1 if x equals the key, 0 to go left (x smaller), 1 to go right.
    int cmp(int x) const{
        if(x < v) return 0;
        if(x > v) return 1;
        return -1;
    }

    // Recomputes s from the cached sizes of the children.
    void maintain(){
        int total = 1;
        for(int i = 0; i < 2; ++i)
            if(ch[i] != NULL) total += ch[i]->s;
        s = total;
    }
};
// Single rotation of the subtree rooted at o. The child on side d^1 is lifted
// to the root (d == 0: left rotation, d == 1: right rotation). The two nodes
// whose subtrees changed are re-sized, and o is re-pointed at the new root.
void rotate(Node* &o, int d){
    const int up = d ^ 1;          // side of the child that moves up
    Node* lifted = o->ch[up];
    o->ch[up] = lifted->ch[d];     // lifted's inner subtree is handed to o
    lifted->ch[d] = o;
    o->maintain();                 // o is now the lower node: refresh it first
    lifted->maintain();
    o = lifted;
}
// Inserts key x into the treap rooted at o. Equal keys go to the right
// subtree. After the recursive insert, the child on the insertion side is
// rotated above o when its priority r is larger, which keeps the priorities
// a max-heap. o's subtree size is refreshed before returning.
void insert(Node* &o, int x){
    if(o == NULL){
        o = new Node(x);
    }
    else{
        int d = (x < o->v ? 0 : 1);
        insert(o->ch[d], x);
        // Compare priorities, not pointer addresses: the previous
        // `o->ch[d] > o` compared unrelated pointers (unspecified result)
        // and never enforced the heap order that keeps the tree balanced.
        if(o->ch[d]->r > o->r) rotate(o, d^1);
    }
    o->maintain();
}
// Removes one node whose key equals x from the treap rooted at o, if present.
// A matching node with two children is rotated downwards, lifting the child
// with the larger priority (max-heap on r), until it has at most one child.
// It is then spliced out and deleted. Subtree sizes are refreshed on the way
// back up. If x is not in the tree, nothing is changed.
void remove(Node* &o, int x){
    if(o == NULL) return;   // x not present; previously this dereferenced NULL
    int d = o->cmp(x);
    if(d == -1){
        if(o->ch[0] != NULL && o->ch[1] != NULL){
            // Compare priorities, not pointer addresses: lift the child with
            // the larger priority. The doomed node then sits on side d2.
            int d2 = (o->ch[0]->r > o->ch[1]->r ? 1 : 0);
            rotate(o, d2);
            remove(o->ch[d2], x);
        }
        else{
            Node* u = o;
            o = (o->ch[0] == NULL ? o->ch[1] : o->ch[0]);
            delete u;
        }
    }
    else
        remove(o->ch[d], x);
    if(o != NULL) o->maintain();
}
// Static (array-based) treap version
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
using namespace std;
// Capacity of the node pool. Index 0 is reserved as the null node (Node()
// hands out indices starting at 1), so at most maxm-1 nodes can be allocated.
// Node() does not check this bound.
const int maxm = 100010;
// Parallel arrays describing pool node i:
//   ch[i][0/1] - left/right child index (0 = no child)
//   r[i]       - random priority; insert() rotates smaller priorities upwards
//   val[i]     - key stored at the node
//   sum[i]     - total multiplicity of the subtree rooted at i (see maintain())
//   num[i]     - multiplicity of val[i]; duplicates increment this instead of
//                creating new nodes
// cnt is the number of allocated nodes; root is the index of the tree root.
int ch[maxm][2], r[maxm], val[maxm], sum[maxm], num[maxm], cnt, root;
// Allocates the next pool slot and stores its index in rt. The slot becomes a
// leaf holding one copy of x with a random priority.
void Node(int &rt, int x){
    const int id = ++cnt;
    rt = id;
    ch[id][0] = 0;
    ch[id][1] = 0;
    r[id] = rand();
    val[id] = x;
    sum[id] = 1;   // subtree total: just this node
    num[id] = 1;   // one copy of x
}
// Recomputes sum[rt] as the multiplicity of node rt plus the totals of its
// two child subtrees. A missing child is index 0, whose sum init() sets to 0.
void maintain(int rt){
    const int left = ch[rt][0];
    const int right = ch[rt][1];
    sum[rt] = num[rt] + sum[left] + sum[right];
}
// Resets the pool and builds a tree containing only the sentinel key
// 2000000001. Slot 0 is prepared as the null node: no children, key 0,
// subtree total 0, and priority INT_MAX ((1LL<<31)-1).
void init()
{
    cnt = root = 0;
    // Null node (index 0).
    val[0] = sum[0] = 0;
    ch[0][1] = ch[0][0] = 0;
    r[0] = (1LL<<31)-1;
    Node(root, 2000000001);
}
// Single rotation of the subtree rooted at index rt. The child on side d^1
// is lifted to the root (d == 0: left rotation, d == 1: right rotation).
// Both moved nodes are re-summed, and rt is updated to the new root index.
void rotate(int &rt, int d){
    const int child = ch[rt][d ^ 1];
    ch[rt][d ^ 1] = ch[child][d];   // child's inner subtree moves under rt
    ch[child][d] = rt;
    maintain(rt);                   // rt is now below child: refresh it first
    maintain(child);
    rt = child;
}
void insert(int &rt, int x){
if(!rt){
Node(rt, x);
return;
}
else{
if(x == val[rt])
num[rt]++;
else
{
int d = x < val[rt] ? 0 : 1;
insert(ch[rt][d], x);
if(r[ch[rt][d]] < r[rt]) rotate(rt, d^1);
}
}
maintain(rt);
}
/*void remove(int &rt, int x){
if(val[rt] == x){
val[rt]--;
if(!val[rt]){
if(!ch[rt][0] && !ch[rt][1])
{
rt = 0;
return;
}
else{
int d = r[ch[rt][0]] > r[ch[rt][1]] ? 1 : 0;
rotate(rt, d);
remove(ch[rt][d], x);
}
else{
}
}
}
else
remove(ch[rt][x>val[rt]], x);
maintain(rt);
}*/
// Returns the k-th smallest key (1-based) in the subtree rooted at rt.
// Duplicates are counted through num[], so a key stored with multiplicity m
// occupies m consecutive ranks. In a tree built by init(), the sentinel
// 2000000001 is part of the counts, so k == sum[root] returns it.
// Returns -1 when k is outside [1, sum[rt]]. Previously this recursed forever
// through the null slot 0. Note that -1 is ambiguous if -1 itself is stored.
int kth(int rt, int k){
    if(k <= 0 || k > sum[rt]) return -1;
    // Invariant: 1 <= k <= sum[rt], so the walk stays on real nodes.
    while(true){
        const int left = sum[ch[rt][0]];
        if(k <= left){
            rt = ch[rt][0];
        }
        else if(k <= left + num[rt]){
            return val[rt];
        }
        else{
            k -= left + num[rt];
            rt = ch[rt][1];
        }
    }
}