// Problem: https://www.luogu.org/problemnew/show/P3369
#include<bits/stdc++.h>
//#define DEBUG
using namespace std;
// Sentinel value returned by queries that have no answer (e.g. no predecessor /
// successor, or an out-of-range rank). 2000000050 still fits in a 32-bit int.
const int inf = 2000000000 + 50;
// One splay-tree node. Equal values share a single node; their count is kept in recy.
class node
{
public:
    node* ch[2];   // children: ch[0] = left, ch[1] = right; NULL when absent
    node* fa;      // parent pointer, NULL for a root
    int val;       // the stored value
    int size;      // number of elements in this subtree, counting duplicates
    int recy;      // how many copies of val this node represents
    node(int x) : fa(NULL), val(x), size(1), recy(1)
    {
        // A fresh node is a detached leaf holding a single copy of x.
        ch[0] = NULL;
        ch[1] = NULL;
    }
};
// Recompute v->size as v's own multiplicity plus the sizes of its non-NULL children.
inline void push_up(node *v)
{
    int total = v->recy;
    for(int side = 0; side < 2; ++side)
    {
        if(v->ch[side]) total += v->ch[side]->size;
    }
    v->size = total;
}
// Install s as p's child on side x (0 = left, 1 = right). When s is non-NULL its parent
// pointer is set to p. The old child in that slot is overwritten; sizes are not updated.
inline void attach(node *p, node *s, int x)
{
    if(s != NULL)
    {
        s->fa = p;   // back-link; independent of the slot write below
    }
    p->ch[x] = s;
}
// Splay tree used as a multiset of ints.
// Storage rule: smaller values go left, larger go right; duplicate values share one
// node whose recy field counts the copies.
class Splay
{
    // In-class initializer so a Splay is a valid empty tree even if init() is never called.
    node* root = NULL;

    // Rotate v one level up over its parent, keeping subtree sizes correct.
    // No-op when v is already the root.
    void rotate(node *v)
    {
        if(v == root) return ;
        node* p = v->fa;
        int flag = p->ch[1] == v;                                // 1 if v is p's right child
        if(p->fa) attach(p->fa, v, p->fa->ch[1] == p);           // v takes p's slot under the grandparent
        else v->fa = NULL, root = v;
        attach(p, v->ch[flag ^ 1], flag);                        // v's inner child moves under p
        attach(v, p, flag ^ 1);                                  // p becomes v's child on the other side
        push_up(p);                                              // p is now below v: update it first
        push_up(v);
    }
public:
    // Reset to an empty tree. Nodes of any previous tree are not freed.
    void init()
    {
        root = NULL;
    }
    node* GetRoot()
    {
        return root;
    }
    // Rotate v up until it has no parent, and make it the root.
    // Returns the new root; a NULL v is ignored (the tree is untouched) and NULL is returned.
    node* splay(node *v)
    {
        if(!v) return NULL;
        for(node *p; (p = v->fa); rotate(v))
        {
            // Same-direction chain (zig-zig): rotate the parent first; otherwise (zig-zag) rotate v.
            if(p->fa) rotate((p->fa->ch[0] == p) == (p->ch[0] == v) ? p : v);
        }
        return root = v;
    }
    // Find the node holding x. If x is absent, the last node on the search path is used
    // instead (it holds x's predecessor or successor). That node is splayed to the root and
    // returned. Returns NULL for an empty tree.
    node* search(int x)
    {
        node* p = root;
        if(!p) return NULL;   // empty tree: nothing to splay (previously dereferenced NULL)
        while(true)
        {
            if(p->val == x) break;
            if(p->ch[p->val < x]) p = p->ch[p->val < x];
            else break;
        }
        splay(p);
        return p;
    }
    // Insert one copy of x. Returns the node now holding x, which is the root afterwards.
    node* insert(int x)
    {
        if(!root) return root = new node(x);
        node* p = search(x);
        if(p->val == x)
        {
            p->recy++, p->size++;
            return p;
        }
        // p is x's neighbour at the root: split p's tree around x under a new root v.
        node* v = new node(x);
        int flag = p->val > x;            // side of v on which p (and its far subtree) goes
        attach(v, p, flag);
        attach(v, p->ch[flag ^ 1], flag ^ 1);
        p->ch[flag ^ 1] = NULL;
        v->fa = NULL;
        push_up(p);push_up(v);
        return root = v;
    }
    // Remove one copy of x. Returns false if x is not present (including an empty tree).
    bool erase(int x)
    {
        node* p = search(x);              // on success p is the root
        if(!p || p->val != x)
            return false;
        if(p->recy > 1)
            p->recy--, p->size--;
        else if(!p->ch[0] && !p->ch[1])
        {
            root = NULL;
            delete p;                     // previously leaked
        }
        else if(!p->ch[0])
        {
            root = p->ch[1];
            root->fa = NULL;
            delete p;
        }
        else if(!p->ch[1])
        {
            root = p->ch[0];
            root->fa = NULL;
            delete p;
        }
        else
        {
            // Detach the left subtree and make the right subtree the tree. Searching for
            // p->val (smaller than everything there) splays the right subtree's minimum
            // to its root. That minimum has no left child, so the old left subtree hangs there.
            node* tmp = p->ch[0];
            tmp->fa = NULL;
            p->ch[0] = NULL;
            root = p->ch[1];
            root->fa = NULL;
            search(p->val);
            attach(root, tmp, 0);
            delete p;
        }
        if(root) push_up(root);
        return true;
    }
    // 1-based rank of x, i.e. 1 + (number of stored elements strictly less than x).
    // With duplicates this is the rank of the first copy. An empty tree gives 1.
    int rank(int x)
    {
        if(!search(x)) return 1;
        int res = (root->ch[0] ? root->ch[0]->size : 0) + 1;
        // If the root holds a value below x (x absent), all its copies also precede x.
        return res + (root->val < x ? root->recy : 0);
    }
    // Value with 1-based rank x. Returns -inf for x <= 0 and inf when x exceeds the size.
    int arank(int x)
    {
        if(x <= 0) return -inf;
        node* p = root;
        while(p)
        {
            int left = p->ch[0] ? p->ch[0]->size : 0;
            if(p->ch[0] && left >= x)
                p = p->ch[0];
            else if(left + p->recy >= x)
            {
                splay(p);
                return p->val;
            }
            else
            {
                x -= left + p->recy;
                p = p->ch[1];
            }
        }
        return inf;
    }
    // Predecessor: the largest stored value strictly less than x, or -inf if none.
    int pre(int x)
    {
        node* p = search(x);
        if(!p) return -inf;               // empty tree
        if(x > p->val) return p->val;
        if(p->ch[0])
        {
            p = p->ch[0];
            while(p->ch[1]) p = p->ch[1];
            return p->val;
        }
        return -inf;
    }
    // Successor: the smallest stored value strictly greater than x, or inf if none.
    int suc(int x)
    {
        node* p = search(x);
        if(!p) return inf;                // empty tree
        if(p->val > x) return p->val;
        if(p->ch[1])
        {
            p = p->ch[1];
            while(p->ch[0]) p = p->ch[0];
            return p->val;
        }
        return inf;
    }
}F;
// Reads t operations "opt x" and applies them to the global tree F:
// 1 insert x, 2 erase one copy of x, 3 print rank of x, 4 print value of rank x,
// 5 print predecessor of x, 6 print successor of x.
int main()
{
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int t, flag, x;
    F.init();
    if(!(cin >> t)) return 0;            // missing count: don't loop on an uninitialized t
    while(t--)
    {
        if(!(cin >> flag >> x)) break;   // truncated input: stop instead of reusing stale values
        // '\n' rather than endl: endl flushes the stream on every query.
        switch(flag)
        {
            case 1: F.insert(x); break;
            case 2: F.erase(x); break;
            case 3: cout << F.rank(x) << '\n'; break;
            case 4: cout << F.arank(x) << '\n'; break;
            case 5: cout << F.pre(x) << '\n'; break;
            case 6: cout << F.suc(x) << '\n'; break;
        }
    }
    return 0;
}