题目来源:BZOJ 3224
思路:
本题用来练习 Splay。经过反复编写,总结出以下几点需要注意:
// Single rotation: lifts `now` one level above its parent.
void rotate(node *now){
node *fa = now->pre, *gra = now->pre->pre;
int wh = now->wh();
fa->set_ch(wh, now->ch[wh^1]);
now->set_ch(wh^1, fa);
now->pre = gra;// Remember to update the parent pointer.
if(gra != null) gra->ch[gra->ch[0] == fa ? 0 : 1] = now;
}
// Rotate `now` upward until its parent is `tar` (tar == null means: to the root).
void splay(node *now, node *tar){
for( ; now->pre != tar; rotate(now))
if(now->pre->pre != tar)
now->wh() == now->pre->wh() ? rotate(now->pre) : rotate(now);
if(tar == null) root = now;// Don't forget to update the root pointer.
}
// Look up `val`; returns its node (splayed to the root) or null if absent.
// On a miss, the last node visited is splayed instead, so that repeated
// unsuccessful searches cannot keep walking the same long path.
node* find(int val){
    node *now = root, *last = null;
    while(now != null){
        if(now->val == val) break; // Break out first rather than returning right away, so the splay below still runs.
        last = now;
        if(val < now->val) now = now->ch[0];
        else now = now->ch[1];
    }
    if(now != null) splay(now, null);
    else if(last != null) splay(last, null);
    return now;
}
// When deleting a value that is stored more than once, just decrement its count (and size) and return.
if(now->cnt > 1){now->cnt --; now->size --; return;}
// Rank of `val` (1-based), or -1 if it is not present.
int get_rank(int val){
node *now = find(val);
if(now == null) return -1;
return now->ch[0]->size + 1;// Adding 1 is enough: equal values share one node, so this is the lowest rank among them.
}
不要把名字相近的变量搞混,例如全局的节点计数器 `cnt` 与节点成员 `now->cnt`(某个值出现的次数)。
代码:
#include <cstdio>
#include <iostream>
// Capacity of the node pool; slot 0 is taken by the `null` sentinel.
const int maxn = 1000000 + 10;
// "No answer" value: kth/pre return -inf, nxt returns inf.
const int inf = 2000000000;
// Splay-tree node. Equal values share a single node: `cnt` counts the
// copies and `size` counts all copies in this node's subtree. Missing
// children/parents point at the shared `null` sentinel (size 0) rather than
// nullptr, so `ch[i]->size` can be read unconditionally.
struct node{
    int val, cnt, size;
    node *ch[2], *pre;
    // Which side of the parent this node hangs on: 0 = left, 1 = right.
    int wh(){ return pre->ch[0] != this; }
    // Recompute `size` from this node's own count plus both subtrees.
    void update(){ size = cnt + ch[0]->size + ch[1]->size; }
    void set_ch(int wh, node *child);
} Pool[maxn], *root, *null;
// Make `child` this node's left (wh == 0) or right (wh == 1) child, then
// refresh this node's size. The sentinel's parent pointer is never written.
void node::set_ch(int wh, node *child){
    if(child != null) child->pre = this;
    ch[wh] = child;
    update();
}
// Number of pool slots handed out so far (slot 0 is the sentinel).
int cnt;
// Take the next unused pool slot and initialise it as a detached leaf
// holding one copy of `val`.
node *one(int val){
    node *fresh = Pool + (++cnt);
    fresh->val = val;
    fresh->cnt = 1;
    fresh->size = 1;
    fresh->ch[0] = fresh->ch[1] = fresh->pre = null;
    return fresh;
}
// Single rotation: lifts `now` one level above its parent `fa` while keeping
// the in-order sequence intact. Statement order matters: fa is re-linked
// (and its size recomputed by set_ch) before it becomes now's child, so the
// update() inside the second set_ch already sees fa's new size.
void rotate(node *now){
node *fa = now->pre, *gra = now->pre->pre;
int wh = now->wh();              // side of fa that `now` is on
fa->set_ch(wh, now->ch[wh^1]);   // now's inner subtree moves over to fa
now->set_ch(wh^1, fa);           // fa becomes now's child on the other side
now->pre = gra;
// Put `now` into fa's old slot under gra. gra's subtree holds the same
// elements as before, so its size is not recomputed here. If gra is the
// sentinel, splay() is the one that updates `root`.
if(gra != null) gra->ch[gra->ch[0] == fa ? 0 : 1] = now;
}
// Rotate `now` upward until its parent is `tar`; with tar == null it becomes
// the new root. Uses zig-zig (rotate parent first) when `now` and its parent
// lean the same way, zig-zag (rotate `now` twice) otherwise.
void splay(node *now, node *tar){
    while(now->pre != tar){
        node *parent = now->pre;
        if(parent->pre != tar){
            if(now->wh() == parent->wh()) rotate(parent);
            else rotate(now);
        }
        rotate(now);
    }
    if(tar == null) root = now;
}
void insert(int val){
node *now = root, *last = null;
while(now != null){
last = now;
if(now->val == val){
now->size ++;
now->cnt ++;
splay(now, null);
return;
}
if(val < now->val) now = now->ch[0];
else now = now->ch[1];
}
now = one(val);
if(last == null) root = now;
else{
if(val < last->val) last->set_ch(0, now);
else last->set_ch(1, now);
splay(now, null);
}
}
// Look up `val`. Returns its node (splayed to the root), or `null` if absent.
// On a miss, the last node visited is splayed instead. Every access ends with
// a splay, which is what keeps a splay tree's amortised cost logarithmic;
// without it, repeated misses down a long path (e.g. the chain left by
// sorted insertions) would cost O(n) each.
node* find(int val){
    node *now = root, *last = null;
    while(now != null){
        if(now->val == val) break;
        last = now;
        if(val < now->val) now = now->ch[0];
        else now = now->ch[1];
    }
    if(now != null) splay(now, null);
    else if(last != null) splay(last, null);
    return now;
}
void del(int val){
node *now = find(val);
if(now == null) return;
if(now->cnt > 1){now->cnt --; now->size --; return;}
if(now->ch[0]==null && now->ch[1]==null) root = null;
else if(now->ch[0] == null) root = now->ch[1], now->ch[1]->pre = null;
else if(now->ch[1] == null) root = now->ch[0], now->ch[0]->pre = null;
else{
node *t = now->ch[0];
while(t->ch[1] != null) t = t->ch[1];
splay(t, now);
t->set_ch(1, now->ch[1]);
t->pre = null, root = t;
}
}
int get_rank(int val){
node *now = find(val);
if(now == null) return -1;
return now->ch[0]->size + 1;
}
int kth(int k){
node *now = root;
int left = k;
while(now != null){
if(left <= now->ch[0]->size+now->cnt && left >= now->ch[0]->size+1){
splay(now, null);
return now->val;
}else if(left <= now->ch[0]->size) now = now->ch[0];
else{
left -= now->ch[0]->size+now->cnt;
now = now->ch[1];
}
}
return -inf;
}
int pre(int val){
int ans = -inf;
node *now = root;
while(now != null){
if(val <= now->val){
now = now->ch[0];
}else if(val > now->val){
ans = std::max(ans, now->val);
now = now->ch[1];
}
}
return ans;
}
int nxt(int val){
int ans = inf;
node *now = root;
while(now != null){
if(val >= now->val){
now = now->ch[1];
}else if(val < now->val){
ans = std::min(ans, now->val);
now = now->ch[0];
}
}
return ans;
}
// Reads q, then q lines "order val":
// 1 insert, 2 delete, 3 rank, 4 k-th, 5 predecessor, 6 successor.
int main(){
    // Pool[0] is the shared sentinel: size 0 and self-linked, so reading
    // ch[i]->size or following pre on it is always safe.
    null = &Pool[0];
    null->val = 0, null->size = 0, null->cnt = 0;
    null->pre = null->ch[0] = null->ch[1] = null;
    root = null;
    int q = 0;
    // Previously an unreadable count left q uninitialised before `while(q--)`.
    if(scanf("%d", &q) != 1) return 0;
    while(q--){
        int order, val;
        if(scanf("%d%d", &order, &val) != 2) break; // truncated/malformed input
        switch(order){
            case 1: insert(val); break;
            case 2: del(val); break;
            case 3: printf("%d\n", get_rank(val)); break;
            case 4: printf("%d\n", kth(val)); break;
            case 5: printf("%d\n", pre(val)); break;
            case 6: printf("%d\n", nxt(val)); break;
            default: break; // unknown operation codes are ignored
        }
    }
    return 0;
}