TREAP 的实现及应用
概念
treap是tree与heap的合成的单词,顾名思义,它既有二叉查找树(tree)的性质, 又符合堆(heap)的性质。
treap中每个节点有两个属性 A A A和 B B B,一个属性符合堆的性质,而另一个属性符合树的性质。如下所示的treap。设字母代表属性 A A A, 数字代表属性 B B B。 其中 A A A属性符合堆的性质:任意一个节点的 A A A属性均小于其儿子;而 B B B属性符合二叉查找树的性质:任意节点的 B B B均大于其左子树中所有节点的 B B B值,且小于其右子树中所有节点的 B B B值,即中序序列为一个有序序列。
让我们回顾一下二叉查找树的构建过程,如果插入的节点,其权值恰好是严格递增或者严格递减的,我们会得到一条链。而如果数据随机,插入的节点其权值没有明显的规律,则基本不会得到一条链,更大概率是比较平衡的一棵二叉树。
同样的道理,如果我们人为给每个节点增加一个属性,这个属性的值是随机给出的,利用这个随机属性来排序,势必会得到一棵比较平衡的二叉树。
那如果树中的节点是动态变化的,有插入、删除操作,则排序就不合适了。但我们还是可以利用该随机属性来影响树的形态,treap即应运而生了。
让这个随机属性,满足heap的性质,即父亲的该属性大于等于(或小于等于)儿子即可。这样树的形态仍然会是比较平衡的。
treap是一种弱平衡的树: 因为它的一个属性是随机分配的,通过随机性来保证不会过度失衡,但也不能保证它严格平衡。
看一道例题:
题目描述 普通平衡树
你需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
1. 插入一个整数x
2. 删除一个整数x(若有多个相同的数,只删除一个)
3. 查询整数x的排名(若有多个相同的数,输出最小的排名),相同的数依次排名,不并列排名
4. 查询排名为x的数,排名的概念同3
5. 求x的前驱(前驱定义为小于x,且最大的数),保证x有前驱
6. 求x的后继(后继定义为大于x,且最小的数),保证x有后继
输入格式
第一行为n,表示操作的个数(n <= 500000)
下面n行每行有两个数opt和x,opt表示操作的序号(1<=opt<=6, -10^7 <= x <= 10^7)
大规模输入数据,建立读入优化
输出格式
对于操作3,4,5,6每行输出一个数,表示对应答案
数据规模
$ 1 \leq n \leq 500000$
采用treap来解决这个题。
节点
节点用结构体表示,其中包括左右儿子,节点的值,节点的优先级等
节点的优先级采用随机数。
struct node{
int ch[2], val, pri, sz;
}arr[MAXN];
旋转操作
旋转操作有一个技巧,如果是从上而下递归做的,以旋转边上方的点作为参数,实现起来相对方便。
void rote(int &r, int flg){
int t = arr[r].ch[1-flg];
arr[r].ch[1-flg] = arr[t].ch[flg];
arr[t].ch[flg] = r;
pushup(r);
pushup(t);
r = t;
}
插入节点
插入函数与普通的二叉查找树差不多,只是当儿子的优先级小于父亲的优先级时,需要做旋转。
void insert(int &r, int x){
if(r == 0){
arr[++pcnt].val = x, arr[pcnt].pri = rand(), ++arr[pcnt].sz;
r = pcnt;
pushup(r);
return;
}
if(x <= arr[r].val) {
insert(arr[r].ch[0], x);
if(arr[r].ch[0] && arr[r].pri > arr[arr[r].ch[0]].pri) rote(r, 1); //right tot
}
else {
insert(arr[r].ch[1], x);
if(arr[r].ch[1] && arr[r].pri > arr[arr[r].ch[1]].pri) rote(r, 0); // left rot
}
pushup(r);
}
删除操作
删除操作和普通的二叉查找树完全一样,不需要旋转。
首先递归地找到待删除节点,找到了以后,判断一下:若待删除节点不足两个儿子时,直接删除,儿子取代它的位置;否则,在左子树中找最大的一个节点(该节点一定没有右儿子),将该节点剪切下来替代待删除节点。
这里有一个地方要格外注意,如果树中存在多个值相同的节点,要注意删除操作是一次删一个还是一次删掉所有。
更好的方法是将多个值相同的节点合并为一个节点,增加一个属性,用于表示该节点出现的次数。
//此处treap中存储了重复节点,删除时采用的由下往上逐一替换,最终保证删除1个节点
int del(int &r, int x){
int tmp;
if(arr[r].val == x || arr[r].val > x && arr[r].ch[0] == 0 || arr[r].val < x && arr[r].ch[1] == 0){
if(arr[r].ch[0] == 0 || arr[r].ch[1] == 0) {
tmp = arr[r].val;
r = arr[r].ch[0] + arr[r].ch[1];
return tmp;
}
else{
tmp = arr[r].val;
arr[r].val = del(arr[r].ch[0], x);
}
}
else if(x < arr[r].val) tmp = del(arr[r].ch[0], x);
else tmp = del(arr[r].ch[1], x);
pushup(r);
return tmp;
}
查找第x个元素
int xth(int r, int x){
if(arr[arr[r].ch[0]].sz >= x) return xth(arr[r].ch[0], x);
else if(arr[arr[r].ch[0]].sz + 1 >= x) return arr[r].val;
else return xth(arr[r].ch[1], x - arr[arr[r].ch[0]].sz - 1);
}
询问x的排名
int getrank(int r, int x){
if(r == 0) return 1;
if(x > arr[r].val)
return arr[arr[r].ch[0]].sz + 1 + getrank(arr[r].ch[1], x);
else return getrank(arr[r].ch[0], x);
}
查找前驱
int getpre(int r, int x){
if(r == 0) return -MOD;
if(arr[r].val >= x) return getpre(arr[r].ch[0], x);
else return max(arr[r].val, getpre(arr[r].ch[1], x));
}
查找后继
int getnxt(int r, int x){
if(r == 0) return MOD;
if(arr[r].val <= x)return getnxt(arr[r].ch[1], x);
else return min(arr[r].val, getnxt(arr[r].ch[0], x));
}
完整代码如下:
#include <bits/stdc++.h>
using namespace std;
#define MAXN 1000005
#define MOD 1000000000
struct node{
int ch[2], val, pri, sz;
}arr[MAXN];
int n, m, opt, x, pcnt, rt;
int cntres;
void pushup(int r){
if(r) arr[r].sz = arr[arr[r].ch[0]].sz + arr[arr[r].ch[1]].sz + 1;
}
void rote(int &r, int flg){
int t = arr[r].ch[1-flg];
arr[r].ch[1-flg] = arr[t].ch[flg];
arr[t].ch[flg] = r;
pushup(r);
pushup(t);
r = t;
}
void insert(int &r, int x){
if(r == 0){
arr[++pcnt].val = x, arr[pcnt].pri = rand(), ++arr[pcnt].sz;
r = pcnt;
pushup(r);
return;
}
if(x <= arr[r].val) {
insert(arr[r].ch[0], x);
if(arr[r].ch[0] && arr[r].pri > arr[arr[r].ch[0]].pri) rote(r, 1); //right tot
}
else {
insert(arr[r].ch[1], x);
if(arr[r].ch[1] && arr[r].pri > arr[arr[r].ch[1]].pri) rote(r, 0); // left rot
}
pushup(r);
}
int del(int &r, int x){
int tmp;
if(arr[r].val == x || arr[r].val > x && arr[r].ch[0] == 0 || arr[r].val < x && arr[r].ch[1] == 0){
if(arr[r].ch[0] == 0 || arr[r].ch[1] == 0) {
tmp = arr[r].val;
r = arr[r].ch[0] + arr[r].ch[1];
return tmp;
}
else{
tmp = arr[r].val;
arr[r].val = del(arr[r].ch[0], x);
}
}
else if(x < arr[r].val) tmp = del(arr[r].ch[0], x);
else tmp = del(arr[r].ch[1], x);
pushup(r);
return tmp;
}
bool find(int r, int x){
if(r == 0) return 0;
if(arr[r].val == x) return 1;
else if(x < arr[r].val) return find(arr[r].ch[0], x);
else return find(arr[r].ch[1], x);
}
int xth(int r, int x){
if(arr[arr[r].ch[0]].sz >= x) return xth(arr[r].ch[0], x);
else if(arr[arr[r].ch[0]].sz + 1 >= x) return arr[r].val;
else return xth(arr[r].ch[1], x - arr[arr[r].ch[0]].sz - 1);
}
int getrank(int r, int x){
if(r == 0) return 1;
if(x > arr[r].val) return arr[arr[r].ch[0]].sz + 1 + getrank(arr[r].ch[1], x);
else return getrank(arr[r].ch[0], x);
}
int getpre(int r, int x){
if(r == 0) return -MOD;
if(arr[r].val >= x) return getpre(arr[r].ch[0], x);
else return max(arr[r].val, getpre(arr[r].ch[1], x));
}
int getnxt(int r, int x){
if(r == 0) return MOD;
if(arr[r].val <= x)return getnxt(arr[r].ch[1], x);
else return min(arr[r].val, getnxt(arr[r].ch[0], x));
}
int main(){
srand(time(0));
int rescnt = 0, res = 0;
scanf("%d", &n);
for(int i = 1; i <= n; i++){
scanf("%d %d", &opt, &x);
if(opt == 1){insert(rt, x); cntres++;}
else if(opt == 2) { del(rt, x), --cntres;}
else if(opt == 3) printf("%d\n",getrank(rt, x));
else if(opt == 4) printf("%d\n", xth(rt, x));
else if(opt == 5) printf("%d\n", getpre(rt, x));
else printf("%d\n", getnxt(rt, x));
}
return 0;
}
非旋转的treap
非旋转treap是由范浩强大佬发明的,它摈弃了旋转操作,而是采用分裂和合并操作来作为基本操作,实现插入、删除节点的同时,维护二叉查找树的有序性和堆的性质。
如下图所示,这是一棵treap,每个节点中的第一个数字为权值,第二个数字为优先级。
权值满足二叉查找树的性质,优先级满足堆的性质。
split操作
分裂操作是非旋转treap中最难理解的操作,它实质上是通过一条路径,将树分成两棵树。我们插入节点、或删除节点,都需要先找到一个插入节点的位置或者删除节点的位置,于是就得到了一条从根到该位置的路径(路径上有可能只有1个点)。
以插入一个权值为20的节点为例:首先从根开始,找到新节点对应的位置,它的位置应该位于节点 J J J的右儿子处。于是我们得到一条从 A A A到 J J J的路径,将这条路径上的边全部断开,小于等于20的点分到左侧;大于20的点分到右侧,于是得到了两棵树:如下图:
接下来生成权值为 20 20 20的节点 Q Q Q,优先级假设为 3 3 3.
void update(int rt){
if(rt == 0) return;
tree[rt].sz = tree[tree[rt].ch[0]].sz + tree[tree[rt].ch[1]].sz + 1;
}
void split(int rt, int &xroot, int &yroot, int v){
//rt表示当前准备分裂的子树,其根为rt。
//xroot表示之前分裂得到的左边那棵树的预留的指针,用于指向下一个将加入的新节点。yroot同理,但它是右边那棵树的。
//xroot其实也表示子树rt即将分裂而成的左树的树根,yroot同理。
if(rt == 0) xroot = yroot = 0;
else if(v < tree[rt].val) yroot = rt, split(tree[rt].ch[0], xroot, tree[yroot].ch[0], v);
else xroot = rt, split(tree[rt].ch[1], tree[xroot].ch[1], yroot, v); //等于v的顶点会分到左树里。
update(rt);
}
接下来我们需要用到合并操作merge。
merge操作
合并操作的对象是两棵树,这两棵树一定满足,左边的树权最大值小于右边的树的权值最小值。我们根据其优先级来合并。为了描述方面,我们设左边的树为 L L L,右边的树为 R R R, 首先比较两棵树的树根,谁优先级小,谁就作为新的树根,假设 L L L的优先级较小,则问题转换为 L L L的右子树与 R R R的合并问题了;否则就是 R R R的根作为新树的根,问题转换为 L L L和 R . l s o n R.lson R.lson的合并问题了,这样递归下去,直到某棵树为空,则递归结束。
合并操作比较简单,就不上图了。
void merge(int &rt, int xroot, int yroot){
//rt表示即将合并得到的新树的根,xroot表示当前参与合并的左树的根,yroot表示当前参与合并的右树的根。
if(xroot == 0 || yroot == 0) rt = xroot + yroot;
else if(tree[xroot].pri < tree[yroot].pri){
rt = xroot;
merge(tree[rt].ch[1], tree[rt].ch[1], yroot);
}
else {
rt = yroot;
merge(tree[rt].ch[0], xroot, tree[rt].ch[0]);
}
update(rt);
}
插入操作
插入节点 Q Q Q,可以先将树分裂为两棵树,先将子树 A A A和 Q Q Q合并,再将合并的新树继续与子树 C C C合并即可。
删除操作
删除操作的代码,也是可以用 s p l i t split split和 m e r g e merge merge操作来完成的。比如要删除权值为 x x x的节点。
先通过一次split操作,将权值小于等于 x x x和权值大于 x x x的节点分开,然后再将权值等于 x x x的和小于 x x x的分开,此时得到了三棵树。
接下来将权值小于 x x x的树与权值大于 x x x的树进行 m e r g e merge merge合并,即达到了删除 x x x的效果。
模板题 普通平衡树
你需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
1. 插入一个整数x
2. 删除一个整数x(若有多个相同的数,只删除一个)
3. 查询整数x的排名(若有多个相同的数,输出最小的排名),相同的数依次排名,不并列排名
4. 查询排名为x的数,排名的概念同3
5. 求x的前驱(前驱定义为小于x,且最大的数),保证x有前驱
6. 求x的后继(后继定义为大于x,且最小的数),保证x有后继
输入格式
第一行为n,表示操作的个数(n <= 500000)
下面n行每行有两个数opt和x,opt表示操作的序号($1 \lt opt \leq 6, -10^7 \leq x \leq 10^7)$
大规模输入数据,建立读入优化
输出格式
对于操作3,4,5,6每行输出一个数,表示对应答案
数据规模
1 ≤ n ≤ 500000 1 \leq n \leq 500000 1≤n≤500000
采用非旋转treap来解决这个题
#include <bits/stdc++.h>
using namespace std;
#define MAXN 1000005
#define INF 999999999
int n;
struct node{
int ch[2], val, pri, sz;
}tree[MAXN];
int rt, root1, root2, tot;
void update(int rt){
if(rt == 0) return;
tree[rt].sz = tree[tree[rt].ch[0]].sz + tree[tree[rt].ch[1]].sz + 1;
}
void split(int rt, int &xroot, int &yroot, int v){
//rt表示当前准备分裂的子树,其根为rt。
//xroot表示之前分裂得到的左边那棵树的预留的指针,用于指向下一个将加入的新节点。yroot同理,但它是右边那棵树的。
//xroot其实也表示子树rt即将分裂而成的左树的树根,yroot同理。
if(rt == 0) xroot = yroot = 0;
else if(v < tree[rt].val) yroot = rt, split(tree[rt].ch[0], xroot, tree[yroot].ch[0], v);
else xroot = rt, split(tree[rt].ch[1], tree[xroot].ch[1], yroot, v); //等于v的顶点会分到左树里。
update(rt);
}
void merge(int &rt, int xroot, int yroot){
//rt表示即将合并得到的新树的根,xroot表示当前参与合并的左树的根,yroot表示当前参与合并的右树的根。
if(xroot == 0 || yroot == 0) rt = xroot + yroot;
else if(tree[xroot].pri < tree[yroot].pri){
rt = xroot;
merge(tree[rt].ch[1], tree[rt].ch[1], yroot);
}
else {
rt = yroot;
merge(tree[rt].ch[0], xroot, tree[rt].ch[0]);
}
update(rt);
}
void insert(int &rt, int v){
split(rt, root1, root2, v);
tree[++tot].val = v, tree[tot].pri = rand(), tree[tot].sz = 1, rt = tot;
merge(root1, root1, tot);
merge(rt, root1, root2);
}
void del(int &rt, int v){
int z;
split(rt, root1, root2, v); //将当前子树rt从节点v处分裂为两棵树,v放入左边的树中。
split(root1, root1, z, v - 1);
merge(z, tree[z].ch[0], tree[z].ch[1]);
merge(rt, root1, z);
merge(rt, rt, root2);
}
int getxth(int rt, int v){
if(v <= tree[tree[rt].ch[0]].sz) return getxth(tree[rt].ch[0], v);
else if(v <= tree[tree[rt].ch[0]].sz + 1) return tree[rt].val;
else return getxth(tree[rt].ch[1], v - tree[tree[rt].ch[0]].sz - 1);
}
int getrank(int rt, int v){
if(rt == 0) return 1;
if(v <= tree[rt].val) return getrank(tree[rt].ch[0], v);
else return tree[tree[rt].ch[0]].sz + 1 + getrank(tree[rt].ch[1], v);
}
int getpre(int rt, int v){
if(rt == 0) return -INF;
if(v <= tree[rt].val) return getpre(tree[rt].ch[0], v);
else{
return max(tree[rt].val, getpre(tree[rt].ch[1], v));
}
}
int getnxt(int rt, int v){
if(rt == 0) return INF;
if(v >= tree[rt].val) return getnxt(tree[rt].ch[1], v);
else return min(tree[rt].val, getnxt(tree[rt].ch[0], v));
}
int main(){
int opt, x;
scanf("%d", &n);
for(int i = 1; i <= n; i++){
scanf("%d %d", &opt, &x);
switch (opt)
{
case 1: insert(rt, x); break;
case 2: del(rt, x); break;
case 4: printf("%d\n",getxth(rt, x)); break;
case 3: printf("%d\n",getrank(rt, x)); break;
case 5: printf("%d\n",getpre(rt, x)); break;
default: printf("%d\n", getnxt(rt,x)); break;
}
}
}