// 特别鸣谢: zky学长 OuO (special thanks to senior schoolmate zky)
// Blog传送门 (link to the blog post)
#include<algorithm>
#include<climits>
#include<cstdio>
#include<cstdlib>
#include<iostream>
using namespace std;
// Sentinel "infinity" returned by SplayTree::prev()/next() when no neighbour
// exists. Parenthesized so the macro expands safely inside larger expressions:
// unparenthesized, `a % INF` parsed as `a % INT_MAX / 3 * 2` and `2 * INF`
// evaluated `2 * INT_MAX` first (signed overflow).
#define INF (INT_MAX/3*2)
// A multiset of ints stored in a binary search tree.
//
// Despite the name this is a hybrid:
//  - insert()/remove() are treap operations. Each node gets a random
//    priority r, and a child with a smaller r is rotated above its parent.
//  - splay()/merge()/split() are rank-based splay-tree operations.
// Splaying rotates without looking at r, so after a splay the treap heap
// order on r may no longer hold. BST order and the size fields stay correct
// either way.
//
// Equal values share one node:
//  - s counts the copies of v held by that node.
//  - size counts all values (with multiplicity) in the node's subtree.
// A shared sentinel `null` is used instead of NULL pointers, so child fields
// can always be read. It has size == s == 0, r == INT_MAX, and both children
// pointing to itself.
struct SplayTree{
    struct Node{
        Node *ch[2];     // ch[0]: smaller values, ch[1]: larger values
        int r,v,s,size;  // r: priority, v: value, s: copies of v, size: subtree total
        Node(int v,Node* nl):v(v){size=s=1;r=rand();ch[0]=ch[1]=nl;}
        // Recompute size from the (already up to date) children.
        void maintain(){size=ch[0]->size+s+ch[1]->size;}
    }*root,*null;

    SplayTree(){
        null=new Node(0,0);
        null->r=INT_MAX;  // never smaller than a real priority, so insert() never lifts it
        null->size=null->s=0;
        null->ch[0]=null->ch[1]=null;
        root=null;
    }

    // Rotate o so that its child ch[d^1] becomes the subtree root.
    // The old root becomes the new root's ch[d].
    // rotate(o,0) is a left rotation; rotate(o,1) is a right rotation.
    void rotate(Node* &o,int d){
        Node* k=o->ch[d^1];
        o->ch[d^1]=k->ch[d];
        k->ch[d]=o;
        o->maintain();
        k->maintain();
        o=k;
    }

    // Insert one copy of x into the subtree rooted at o (treap insertion).
    void insert(Node* &o,int x){
        if(o==null){o=new Node(x,null);return;}
        if(o->v==x){o->s++;o->size++;return;}
        int d=x>o->v;                          // 0: go left, 1: go right
        insert(o->ch[d],x);
        if(o->ch[d]->r<o->r) rotate(o,d^1);    // lift the child to restore min-heap order on r
        else o->maintain();
    }

    // Remove one copy of x from the subtree rooted at o.
    // Does nothing if x is not present.
    void remove(Node* &o,int x){
        if(o==null) return;                    // x not found
        if(o->v==x){
            if(o->s>1){o->s--;o->size--;return;}
            if(o->ch[0]==null||o->ch[1]==null){
                // At most one child: splice it into o's place.
                Node *old=o;
                o=(o->ch[0]==null)?o->ch[1]:o->ch[0];
                delete old;
                return;
            }
            // Two children: lift the child with the smaller priority (index d).
            // rotate(o,d^1) moves the node being deleted to o->ch[d^1];
            // keep pushing it down until it can be spliced out.
            int d=o->ch[0]->r>o->ch[1]->r;
            rotate(o,d^1);
            remove(o->ch[d^1],x);
        }
        else remove(o->ch[x>o->v],x);
        o->maintain();
    }

    // Return the x-th smallest value (1-based, counting multiplicity) in
    // subtree o. Returns -INF if x < 1 and INF if x exceeds the subtree size.
    int kth(Node *o,int x){
        if(o==null) return x<=0 ? -INF : INF;
        int d=o->ch[0]->size;
        if(x<=d) return kth(o->ch[0],x);
        if(x>d+o->s) return kth(o->ch[1],x-d-o->s);
        return o->v;
    }

    // Number of stored values strictly less than x in subtree o (counting
    // multiplicity). x does not have to be present.
    int rank(Node *o,int x){
        if(o==null) return 0;
        int d=o->ch[0]->size;
        if(x<o->v) return rank(o->ch[0],x);
        if(x>o->v) return rank(o->ch[1],x)+d+o->s;
        return d;
    }

    // Predecessor query in subtree o:
    //  - returns x itself if x is stored with more than one copy;
    //  - otherwise the largest stored value < x;
    //  - or -INF if there is none.
    int prev(Node *o,int x){
        if(o==null) return -INF;
        if(x==o->v&&o->s>1) return o->v;
        if(x<=o->v) return prev(o->ch[0],x);
        return std::max(o->v,prev(o->ch[1],x));
    }

    // Successor query in subtree o:
    //  - returns x itself if x is stored with more than one copy;
    //  - otherwise the smallest stored value > x;
    //  - or INF if there is none.
    int next(Node* o,int x){
        if(o==null) return INF;
        if(x==o->v&&o->s>1) return o->v;
        if(x>=o->v) return next(o->ch[1],x);
        return std::min(o->v,next(o->ch[0],x));
    }

    // Locate the k-th smallest value (1-based, with multiplicity) relative to o.
    // Returns -1 if o itself holds it, 0 if it lies in the left subtree,
    // and 1 otherwise.
    int cmprk(Node *o,int k){
        int d;
        if((o->ch[0]->size)<k&&(o->ch[0]->size+o->s)>=k) d=-1;
        else if(o->ch[0]->size>=k) d=0;
        else d=1;
        return d;
    }

    /* Unused single-rotation variant of splay(), kept for reference:
    void splay(Node* &o,int x){
        int d=o->ch[0]==null?0:o->ch[0]->size;
        if(d<x){
            if(d+o->s<x){splay(o->ch[1],x-d-o->s);rotate(o,0);}
            else return;
        }
        else{splay(o->ch[0],x);rotate(o,1);}
    }*/

    // Splay with double rotations: bring the node holding the k-th smallest
    // value (1-based, counting multiplicity) of subtree o up to be its root.
    // A k below 1 splays the minimum node; a k above the size splays the maximum.
    void splay(Node* &o,int k){
        if(o==null) return;
        int d=cmprk(o,k);
        if(d==-1) return;                      // target already at the root
        if(d==1) k-=o->ch[0]->size+o->s;       // rank within the right subtree
        Node *p=o->ch[d];
        if(p==null) return;
        int d2=cmprk(p,k);
        if(d2!=-1&&p->ch[d2]!=null){
            // Rank of the target inside p->ch[d2]. For the right grandchild,
            // skip p's LEFT subtree and p's own copies.
            int k2=(d2==0?k:k-(p->ch[0]->size+p->s));
            splay(p->ch[d2],k2);
            if(d==d2) rotate(o,d^1);           // zig-zig: rotate at the grandparent first
            else rotate(o->ch[d],d);           // zig-zag: rotate at the parent first
        }
        rotate(o,d^1);
    }

    // Merge two trees and return the new root.
    // Assumes (not checked) that every value in `left` is smaller than every
    // value in `right`.
    Node *merge(Node *left,Node *right){
        if(left==null) return right;           // never write into the sentinel
        splay(left,left->size);                // maximum of left to the root; its right child is now null
        left->ch[1]=right;
        left->maintain();
        return left;
    }

    // Split o so that the k smallest values go to `left` and the rest to `right`.
    // Equal values share a node, so if the k-th value has several copies,
    // all of them go to `left` (left may then hold more than k values).
    void split(Node *o,int k,Node* &left,Node* &right){
        if(o==null||k<=0){left=null;right=o;return;}
        splay(o,k);
        left=o;
        right=o->ch[1];
        o->ch[1]=null;
        left->maintain();
    }
}T;