Treap(树堆)
Treap(树堆)是一种平衡二叉搜索树实现。分为带旋转的Treap和无旋转的Treap。其中,无旋转的Treap相比于AVL来说,其实现简单,方便快捷,是OI竞赛中的不二之选,但是因为其依靠随机化的因素,其平衡是期望平衡的,最坏情况退化为链表,但概率不大。其中无旋Treap(以下Treap均指无旋Treap),分为两个重要的操作,分别为Split(分割)、merge(合并)。
本文以为P3369例题,讲解其实现。
节点结构体
其节点和二叉树基本相同,其中prior成员需要特别注意,这个成员在合并的时候作为合并参考值。
struct Node
{
int prior; // 节点优先级
int key; // 节点关键字
int l; // 左儿子节点
int r; // 右儿子节点
int siz; // 子树的大小
int cnt; // 节点的重度
} t[1000000];
以下为分配节点代码,其中,我们注意到:
t[tot].prior = rand();
prior为随机值,所以我们说Treap是期望平衡的。
int tot = 0;
int createNode(int key)
{
tot++;
t[tot].key = key;
t[tot].prior = rand();
return tot;
}
Split(分割)
函数原型如下:
pair<int, int> split(int root, int key);
给定一个Treap树,其root为根节点,将这颗树分解成两个Treap树,其中pair.first树的所有节点的权值都小于等于key,pair.second树的所有节点的权值都大于key。
我们判断root.key与key的关系,如果key<root.key那么说明根节点root连带其右子树都是pair.second的部分,我们递归分解左子树,然后将分解左子树的结果记为sp的sp.second作为root的新左子树,然后返回(sp.first,root)即可。
同样的,key大于等于root.key的时候,我们分解其右子树,以相对称的方式合并子树即可。
pair<int, int> split(int root, int key)
{
if (root == 0)
{
return make_pair(0, 0);
}
if (key < t[root].key)
{
pair<int, int> lt = split(t[root].l, key);
t[root].l = lt.second;
t[root].siz = t[t[root].l].siz + t[t[root].r].siz + t[root].cnt;
return make_pair(lt.first, root);
}
else
{
pair<int, int> rt = split(t[root].r, key);
t[root].r = rt.first;
t[root].siz = t[t[root].l].siz + t[t[root].r].siz + t[root].cnt;
return make_pair(root, rt.second);
}
}
Merge(合并)
函数原型如下:
int merge(int u, int v);
合并两个Treap子树u和v,其中要求u的所有节点的key都小于等于v所有节点的key。
合并两个Treap有两种不同的方法,采取哪种方法取决于uv中prior的大小的关系,如果,u的prior大于v,那么,将其u的右子树和v整个合并,作为u的新右子树。
同样的,如果u的key值小于等于v的key值,那么就合并v的左子树和u作为v的新左子树即可。
int merge(int u, int v)
{
if (u == 0)
return v;
if (v == 0)
return u;
if (t[u].prior > t[v].prior)
{
t[u].r = merge(t[u].r, v);
t[u].siz = t[t[u].l].siz + t[t[u].r].siz + t[u].cnt;
return u;
}
else
{
t[v].l = merge(u, t[v].l);
t[v].siz = t[t[v].l].siz + t[t[v].r].siz + t[v].cnt;
return v;
}
}
我们还发现,prior的规则符合堆的规则,即两个子节点的prior的值必定小于等于父节点的key值,其具有堆的特性,又具有搜索二叉树的形式,所以叫Treap(树堆),其中可以证明,将一个随机序列建堆,我们得到的堆的形式更符合平衡树的形式。
提取节点
如果我们想提取树堆中的一个节点,我们该怎么做呢?
我们先将堆分割成两个部分,按照key-1分割,记为 s p sp sp,其中 s p . f i r s t sp.first sp.first的节点值全部小于等于key-1, s p . s e c o n d sp.second sp.second的值全部大于key-1,然后,我们再次按key分割 s p . s e c o n d sp.second sp.second,记为 s s p ssp ssp,那么 s s p . f i r s t ssp.first ssp.first的值将是权值为key的节点, s s p . s e c o n d ssp.second ssp.second的节点都是大于key的节点。
这样 s s p . f i r s t ssp.first ssp.first就是我们想要的节点了。
pair<int, int> sp1 = split(root, key - 1);
pair<int, int> sp2 = split(sp1.second, key);
插入节点
int insert(int root, int key)
{
pair<int, int> sp1 = split(root, key - 1);
pair<int, int> sp2 = split(sp1.second, key);
if (sp2.first == 0)
{
sp2.first = createNode(key);
}
t[sp2.first].siz++;
t[sp2.first].cnt++;
return merge(merge(sp1.first, sp2.first), sp2.second);
}
删除节点
int del(int root, int key)
{
pair<int, int> sp1 = split(root, key - 1);
pair<int, int> sp2 = split(sp1.second, key);
if (sp2.first != 0)
{
if (t[sp2.first].cnt == 1)
{
sp2.first = 0;
}
else
{
t[sp2.first].cnt--;
t[sp2.first].siz--;
}
}
return merge(merge(sp1.first, sp2.first), sp2.second);
}
排名查询
int count(int root, int key, int &ans)
{
pair<int, int> sp = split(root, key - 1);
ans = t[sp.first].siz + 1;
return merge(sp.first, sp.second);
}
int query(int root, int x)
{
if (t[t[root].l].siz < x && t[t[root].l].siz + t[root].cnt >= x)
{
return t[root].key;
}
if (t[t[root].l].siz >= x)
return query(t[root].l, x);
return query(t[root].r, x - t[t[root].l].siz - t[root].cnt);
}
查询前驱、后继节点
int prv(int root, int key, int &ans)
{
pair<int, int> sp = split(root, key - 1);
int curr = sp.first;
while (curr != 0 && t[curr].r != 0)
{
curr = t[curr].r;
}
ans = t[curr].key;
return merge(sp.first, sp.second);
}
int nxt(int root, int key, int &ans)
{
pair<int, int> sp = split(root, key);
int curr = sp.second;
while (curr != 0 && t[curr].l != 0)
{
curr = t[curr].l;
}
ans = t[curr].key;
return merge(sp.first, sp.second);
}
模板
#include <bits/stdc++.h>
using namespace std;
#define FR freopen("in.txt", "r", stdin)
#define FW freopen("out11.txt", "w", stdout)
#define MOD 998244353
typedef long long ll;
struct Node
{
int prior; // 节点优先级
int key; // 节点关键字
int l; // 左儿子节点
int r; // 右儿子节点
int siz; // 子树的大小
int cnt; // 节点的重度
} t[1000000];
int tot = 0;
int createNode(int key)
{
tot++;
t[tot].key = key;
t[tot].prior = rand();
return tot;
}
pair<int, int> split(int root, int key)
{
if (root == 0)
{
return make_pair(0, 0);
}
if (key < t[root].key)
{
pair<int, int> lt = split(t[root].l, key);
t[root].l = lt.second;
t[root].siz = t[t[root].l].siz + t[t[root].r].siz + t[root].cnt;
return make_pair(lt.first, root);
}
else
{
pair<int, int> rt = split(t[root].r, key);
t[root].r = rt.first;
t[root].siz = t[t[root].l].siz + t[t[root].r].siz + t[root].cnt;
return make_pair(root, rt.second);
}
}
int merge(int u, int v)
{
if (u == 0)
return v;
if (v == 0)
return u;
if (t[u].prior > t[v].prior)
{
t[u].r = merge(t[u].r, v);
t[u].siz = t[t[u].l].siz + t[t[u].r].siz + t[u].cnt;
return u;
}
else
{
t[v].l = merge(u, t[v].l);
t[v].siz = t[t[v].l].siz + t[t[v].r].siz + t[v].cnt;
return v;
}
}
int insert(int root, int key)
{
pair<int, int> sp1 = split(root, key - 1);
pair<int, int> sp2 = split(sp1.second, key);
if (sp2.first == 0)
{
sp2.first = createNode(key);
}
t[sp2.first].siz++;
t[sp2.first].cnt++;
return merge(merge(sp1.first, sp2.first), sp2.second);
}
int del(int root, int key)
{
pair<int, int> sp1 = split(root, key - 1);
pair<int, int> sp2 = split(sp1.second, key);
if (sp2.first != 0)
{
if (t[sp2.first].cnt == 1)
{
sp2.first = 0;
}
else
{
t[sp2.first].cnt--;
t[sp2.first].siz--;
}
}
return merge(merge(sp1.first, sp2.first), sp2.second);
}
int count(int root, int key, int &ans)
{
pair<int, int> sp = split(root, key - 1);
ans = t[sp.first].siz + 1;
return merge(sp.first, sp.second);
}
void update(int root)
{
if (root == 0)
return;
update(t[root].l);
update(t[root].r);
t[root].siz = t[t[root].l].siz + t[t[root].r].siz + t[root].cnt;
}
int query(int root, int x)
{
if (t[t[root].l].siz < x && t[t[root].l].siz + t[root].cnt >= x)
{
return t[root].key;
}
if (t[t[root].l].siz >= x)
return query(t[root].l, x);
return query(t[root].r, x - t[t[root].l].siz - t[root].cnt);
}
int prv(int root, int key, int &ans)
{
pair<int, int> sp = split(root, key - 1);
int curr = sp.first;
while (curr != 0 && t[curr].r != 0)
{
curr = t[curr].r;
}
ans = t[curr].key;
return merge(sp.first, sp.second);
}
int nxt(int root, int key, int &ans)
{
pair<int, int> sp = split(root, key);
int curr = sp.second;
while (curr != 0 && t[curr].l != 0)
{
curr = t[curr].l;
}
ans = t[curr].key;
return merge(sp.first, sp.second);
}
int ans = 0;
void output(int root)
{
if (root == 0)
return;
output(t[root].l);
if (t[root].key == 815809)
printf(" !%d! ", ans);
printf("%d:%d ", t[root].key, t[root].cnt);
ans += t[root].cnt;
output(t[root].r);
}
int main()
{
srand(time(NULL));
srand(rand());
srand(rand());
srand(rand());
int T;
int root = 0;
scanf("%d", &T);
int ans = 0;
while (T--)
{
int op;
int x;
scanf("%d %d", &op, &x);
switch (op)
{
case 1:
root = insert(root, x);
break;
case 2:
root = del(root, x);
break;
case 3:
root = count(root, x, ans);
printf("%d\n", ans);
break;
case 4:
ans = query(root, x);
printf("%d\n", ans);
break;
case 5:
root = prv(root, x, ans);
printf("%d\n", ans);
break;
case 6:
root = nxt(root, x, ans);
printf("%d\n", ans);
break;
default:
break;
}
}
return 0;
}