不用手写的情况
如果题目中仅仅涉及到以下几种操作,我们可以用 s e t set set代替;
括号为 s e t set set对应的操作
- 插入 ( i n s e r t ) (insert) (insert)
- 删除 ( e r a s e ) (erase) (erase)
- 找中序遍历的前驱/后继 ( − − / + + ) (--/++) (−−/++)
- 找最小/最大 ( b e g i n / e n d − 1 ) (begin/end-1) (begin/end−1)
注意这里的前驱和后继是树中存在的
有些题目会让我们求大于 x x x的最小值以及小于 x x x的最大值,这些值树中是不存在的,就需要我们手写平衡树了;
引入
T r e a p = B S T + H e a p Treap = BST +Heap Treap=BST+Heap
其中体现 B S T BST BST的部分在于根据 k e y key key进行查找;
体现 H e a p Heap Heap的部分在于随机生成 v a l val val,
并基于 v a l val val左右旋转,防止退化;
下图来自OI-Wiki
左右旋操作配合代码一起食用…(很简单的)
例题一
Code
#include <iostream>
#include <cstdio>
#include <ctime>
#include <algorithm>
using namespace std;
typedef long long ll;
const int N = 1e5 + 10;
struct Node{
int lc,rc;
//val是随机生成的,调节树的平衡,防止题目卡我们
//val是堆的值,我们维护的是大根堆,大的在上面
int key,val;
int cnt,size;//cnt是key值的数量,size是整棵树的大小;
}tr[N];//平衡树最多N个结点
const int INF = 1e9;
int idx,root;
int get_node(int key){
int p = ++idx;//链式存储
tr[p].key = key;
tr[p].val = rand();
tr[p].size = tr[p].cnt = 1;
return p;
}
void push_up(int p){
tr[p].size = tr[tr[p].lc].size + tr[tr[p].rc].size + tr[p].cnt;
}
//左旋,传引用是因为root会变,我们需要同步更新
void Lrotate(int &p){
//结合图来理解
int q = tr[p].rc;
tr[p].rc = tr[q].lc;
tr[q].lc = p;
p = q;
//注意先更新儿子节点
push_up(tr[p].lc);
push_up(p);
}
void Rrotate(int &p){
//结合图理解
int q = tr[p].lc;
tr[p].lc = tr[q].rc;
tr[q].rc = p;
p = q;
push_up(tr[p].rc);
push_up(p);
}
void build(){
//添加两个哨兵-INF与INF
root = get_node(-INF);
tr[root].rc = get_node(INF);
push_up(root);
if(tr[root].val < tr[tr[root].rc].val){
//左旋
Lrotate(root);
}
}
void insert(int &p,int key){
if(!p){
p = get_node(key);
return;
}
if(tr[p].key == key){
++tr[p].cnt;
}else if(tr[p].key > key){
//去左子树
insert(tr[p].lc,key);
if(tr[tr[p].lc].val > tr[p].val){
//右旋
Rrotate(p);
}
}else{
//去右子树
insert(tr[p].rc,key);
if(tr[tr[p].rc].val > tr[p].val){
//左旋
Lrotate(p);
}
}
push_up(p);
return;
}
void erase(int &p,int key){
if(!p) return;//不存在
if(tr[p].key == key){
if(tr[p].cnt > 1) --tr[p].cnt;
//如果不是叶子节点要删除的话
//通过旋转让它增加深度
else if(tr[p].lc || tr[p].rc){
if(!tr[p].rc || tr[tr[p].lc].val > tr[tr[p].rc].val){
//右旋
Rrotate(p);
erase(tr[p].rc,key);//要删除的点跑到右子树去了
}
else{
//左旋
Lrotate(p);
erase(tr[p].lc,key);
}
}else{
//如果是叶子节点,直接标为0
p = 0;
}
}
else if(tr[p].key > key){
//左子树
erase(tr[p].lc,key);
}else{
//右子树
erase(tr[p].rc,key);
}
push_up(p);
}
int query_rank_by_key(int p,int key){
//找不到的话,根据题目返回
if(!p) return 0;
if(tr[p].key == key){
//左子树的全部加一就是我们了
return tr[tr[p].lc].size + 1;
}
if(tr[p].key > key){
//去左子树
return query_rank_by_key(tr[p].lc,key);
}else{
//右子树,需要先去掉左树以及根
return tr[tr[p].lc].size
+ tr[p].cnt + query_rank_by_key(tr[p].rc,key);
}
}
int query_key_by_rank(int p,int rank){
//不存在根据题目返回
if(!p) return INF;
//在左树
if(tr[tr[p].lc].size >= rank) return query_key_by_rank(tr[p].lc,rank);
//在当前节点,注意顺序
if(tr[p].cnt + tr[tr[p].lc].size >= rank) return tr[p].key;
//在右树
return query_key_by_rank(tr[p].rc,rank-tr[p].cnt-tr[tr[p].lc].size);
}
//小于x的最大的数
int query_prev(int p,int key){
//不存在根据题目返回
if(!p) return -INF;
//严格小于的话去左树
if(tr[p].key >= key){
return query_prev(tr[p].lc,key);
}
//否则可能是根也可能在右树
//右子树可能不存在,所以要把根带上
return max(tr[p].key,query_prev(tr[p].rc,key));
}
//大于x的最小的数
int query_next(int p,int key){
if(!p) return INF;
if(tr[p].key <= key){
return query_next(tr[p].rc,key);
}
return min(tr[p].key,query_next(tr[p].lc,key));
}
int main(){
std::ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
build();
int n;
cin >> n;
int opt,x;
while(n--){
cin >> opt >> x;
if(opt == 1){
//从root开始找,下同
insert(root,x);
}else if(opt == 2){
erase(root,x);
}else if(opt == 3){
//因为哨兵的存在,因此需要减去一
cout << query_rank_by_key(root,x)-1 << '\n';
}else if(opt == 4){
//同理,因为哨兵的存在,因此需要加一
cout << query_key_by_rank(root,x+1) << '\n';
}else if(opt == 5){
cout << query_prev(root,x) << '\n';
}else{
cout << query_next(root,x) << '\n';
}
}
return 0;
}
例题二、
思路
这题就要求三个操作;
- 插入某数
- 求小于等于key最小的数
- 求大于等于key最大的数
可以发现我们可以用set
来完成,下面分别有手写平衡树和set
的代码;
手写平衡树
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <algorithm>
using namespace std;
typedef long long ll;
const int N = 1e6 + 10,INF = 1e7;
struct Node{
int lc,rc;
int key,val;
}tr[N];
ll ans;
int n,idx,root;
int new_node(int key){
int p = ++idx;
tr[p].key = key;
tr[p].val = rand();
return p;
}
void Rrotate(int &p){
int q = tr[p].lc;
tr[p].lc = tr[q].rc;
tr[q].rc = p;
p = q;
}
void Lrotate(int &p){
int q = tr[p].rc;
tr[p].rc = tr[q].lc;
tr[q].lc = p;
p = q;
}
void insert(int &p,int key){
if(!p){
p = new_node(key);
return;
}
if(tr[p].key == key) return;
if(tr[p].key > key){
insert(tr[p].lc,key);
if(tr[tr[p].lc].val > tr[p].val){
//右旋
Rrotate(p);
}
}
else{
insert(tr[p].rc,key);
if(tr[tr[p].rc].val > tr[p].val){
//左旋
Lrotate(p);
}
}
}
void build(){
root = new_node(-INF);
tr[root].rc = new_node(INF);
if(tr[root].val < tr[tr[root].rc].val){
//左旋
Lrotate(root);
}
}
int get_prev(int p,int key){ //求一个小于等于key最大的数
if(!p) return -INF;
if(tr[p].key > key){
return get_prev(tr[p].lc,key);
}
return max(tr[p].key,get_prev(tr[p].rc,key));
}
int get_next(int p,int key){ //求一个大于等于x最小的数
if(!p) return INF;
if(tr[p].key < key) return get_next(tr[p].rc,key);
return min(tr[p].key,get_next(tr[p].lc,key));
}
void solve(){
build();
cin >> n;
for(int i=1,x;i<=n;++i){
cin >> x;
if(i == 1) ans += x;
else{
ans += min(x - get_prev(root,x)
,get_next(root,x) - x );
}
//插入key
insert(root,x);
}
cout << ans << '\n';
}
int main(){
std::ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
solve();
return 0;
}
set
#include <iostream>
#include <cstdio>
#include <set>
#include <algorithm>
using namespace std;
typedef long long ll;
const int N = 1e5 + 10,INF = 1e7;
void solve(){
int n;
cin >> n;
set<int> s;
//哨兵
s.insert(-INF);
s.insert(INF);
ll ans = 0;
for(int i=1,key;i<=n;++i){
cin >> key;
if(i == 1){
ans += key;
}else{
//求小于等于key最小的数 以及大于等于key最大的数
auto mx = s.lower_bound(key);
auto mn = mx;
--mn;
ans += min(*mx-key,key-*mn);
}
s.insert(key);
}
cout << ans << '\n';
}
int main(){
std::ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
solve();
return 0;
}