Treap(树堆)

Treap(树堆)

Treap(树堆)是一种平衡二叉搜索树实现。分为带旋转的Treap和无旋转的Treap。其中,无旋转的Treap相比于AVL来说,其实现简单,方便快捷,是OI竞赛中的不二之选,但是因为其依靠随机化的因素,其平衡是期望平衡的,最坏情况退化为链表,但概率不大。其中无旋Treap(以下Treap均指无旋Treap),分为两个重要的操作,分别为Split(分割)、merge(合并)。

本文以为P3369例题,讲解其实现。

节点结构体

其节点和二叉树基本相同,其中prior成员需要特别注意,这个成员在合并的时候作为合并参考值。

struct Node
{
    int prior; // 节点优先级
    int key;   // 节点关键字
    int l;     // 左儿子节点
    int r;     // 右儿子节点
    int siz;   // 子树的大小
    int cnt;   // 节点的重度
} t[1000000];

以下为分配节点代码,其中,我们注意到:

t[tot].prior = rand();

prior为随机值,所以我们说Treap是期望平衡的。

int tot = 0;

int createNode(int key)
{
    tot++;
    t[tot].key = key;
    t[tot].prior = rand();
    return tot;
}

Split(分割)

函数原型如下:

pair<int, int> split(int root, int key);

给定一个Treap树,其root为根节点,将这颗树分解成两个Treap树,其中pair.first树的所有节点的权值都小于等于key,pair.second树的所有节点的权值都大于key。

我们判断root.key与key的关系,如果key<root.key那么说明根节点root连带其右子树都是pair.second的部分,我们递归分解左子树,然后将分解左子树的结果记为sp的sp.second作为root的新左子树,然后返回(sp.first,root)即可。

同样的,key大于等于root.key的时候,我们分解其右子树,以相对称的方式合并子树即可。

pair<int, int> split(int root, int key)
{
    if (root == 0)
    {
        return make_pair(0, 0);
    }

    if (key < t[root].key)
    {
        pair<int, int> lt = split(t[root].l, key);
        t[root].l = lt.second;
        t[root].siz = t[t[root].l].siz + t[t[root].r].siz + t[root].cnt;

        return make_pair(lt.first, root);
    }
    else
    {
        pair<int, int> rt = split(t[root].r, key);
        t[root].r = rt.first;
        t[root].siz = t[t[root].l].siz + t[t[root].r].siz + t[root].cnt;

        return make_pair(root, rt.second);
    }
}

Merge(合并)

函数原型如下:

int merge(int u, int v);

合并两个Treap子树u和v,其中要求u的所有节点的key都小于等于v所有节点的key。

合并两个Treap有两种不同的方法,采取哪种方法取决于uv中prior的大小的关系,如果,u的prior大于v,那么,将其u的右子树和v整个合并,作为u的新右子树。

合并

同样的,如果u的key值小于等于v的key值,那么就合并v的左子树和u作为v的新左子树即可。

int merge(int u, int v)
{
    if (u == 0)
        return v;
    if (v == 0)
        return u;
    if (t[u].prior > t[v].prior)
    {
        t[u].r = merge(t[u].r, v);
        t[u].siz = t[t[u].l].siz + t[t[u].r].siz + t[u].cnt;
        return u;
    }
    else
    {
        t[v].l = merge(u, t[v].l);
        t[v].siz = t[t[v].l].siz + t[t[v].r].siz + t[v].cnt;
        return v;
    }
}

我们还发现,prior的规则符合堆的规则,即两个子节点的prior的值必定小于等于父节点的key值,其具有堆的特性,又具有搜索二叉树的形式,所以叫Treap(树堆),其中可以证明,将一个随机序列建堆,我们得到的堆的形式更符合平衡树的形式。

提取节点

如果我们想提取树堆中的一个节点,我们该怎么做呢?

我们先将堆分割成两个部分,按照key-1分割,记为 s p sp sp,其中 s p . f i r s t sp.first sp.first的节点值全部小于等于key-1, s p . s e c o n d sp.second sp.second的值全部大于key-1,然后,我们再次按key分割 s p . s e c o n d sp.second sp.second,记为 s s p ssp ssp,那么 s s p . f i r s t ssp.first ssp.first的值将是权值为key的节点, s s p . s e c o n d ssp.second ssp.second的节点都是大于key的节点。

这样 s s p . f i r s t ssp.first ssp.first就是我们想要的节点了。

pair<int, int> sp1 = split(root, key - 1);
pair<int, int> sp2 = split(sp1.second, key);

插入节点

int insert(int root, int key)
{
    pair<int, int> sp1 = split(root, key - 1);
    pair<int, int> sp2 = split(sp1.second, key);
    if (sp2.first == 0)
    {
        sp2.first = createNode(key);
    }

    t[sp2.first].siz++;
    t[sp2.first].cnt++;

    return merge(merge(sp1.first, sp2.first), sp2.second);
}

删除节点

int del(int root, int key)
{
    pair<int, int> sp1 = split(root, key - 1);
    pair<int, int> sp2 = split(sp1.second, key);
    if (sp2.first != 0)
    {
        if (t[sp2.first].cnt == 1)
        {
            sp2.first = 0;
        }
        else
        {
            t[sp2.first].cnt--;
            t[sp2.first].siz--;
        }
    }

    return merge(merge(sp1.first, sp2.first), sp2.second);
}

排名查询

int count(int root, int key, int &ans)
{
    pair<int, int> sp = split(root, key - 1);
    ans = t[sp.first].siz + 1;
    return merge(sp.first, sp.second);
}
int query(int root, int x)
{
    if (t[t[root].l].siz < x && t[t[root].l].siz + t[root].cnt >= x)
    {
        return t[root].key;
    }
    if (t[t[root].l].siz >= x)
        return query(t[root].l, x);
    return query(t[root].r, x - t[t[root].l].siz - t[root].cnt);
}

查询前驱、后继节点

int prv(int root, int key, int &ans)
{
    pair<int, int> sp = split(root, key - 1);
    int curr = sp.first;
    while (curr != 0 && t[curr].r != 0)
    {
        curr = t[curr].r;
    }
    ans = t[curr].key;
    return merge(sp.first, sp.second);
}

int nxt(int root, int key, int &ans)
{
    pair<int, int> sp = split(root, key);
    int curr = sp.second;
    while (curr != 0 && t[curr].l != 0)
    {
        curr = t[curr].l;
    }
    ans = t[curr].key;
    return merge(sp.first, sp.second);
}

模板

#include <bits/stdc++.h>

using namespace std;

#define FR freopen("in.txt", "r", stdin)
#define FW freopen("out11.txt", "w", stdout)

#define MOD 998244353

typedef long long ll;

struct Node
{
    int prior; // 节点优先级
    int key;   // 节点关键字
    int l;     // 左儿子节点
    int r;     // 右儿子节点
    int siz;   // 子树的大小
    int cnt;   // 节点的重度
} t[1000000];

int tot = 0;

int createNode(int key)
{
    tot++;
    t[tot].key = key;
    t[tot].prior = rand();
    return tot;
}

pair<int, int> split(int root, int key)
{
    if (root == 0)
    {
        return make_pair(0, 0);
    }

    if (key < t[root].key)
    {
        pair<int, int> lt = split(t[root].l, key);
        t[root].l = lt.second;
        t[root].siz = t[t[root].l].siz + t[t[root].r].siz + t[root].cnt;

        return make_pair(lt.first, root);
    }
    else
    {
        pair<int, int> rt = split(t[root].r, key);
        t[root].r = rt.first;
        t[root].siz = t[t[root].l].siz + t[t[root].r].siz + t[root].cnt;

        return make_pair(root, rt.second);
    }
}

int merge(int u, int v)
{
    if (u == 0)
        return v;
    if (v == 0)
        return u;
    if (t[u].prior > t[v].prior)
    {
        t[u].r = merge(t[u].r, v);
        t[u].siz = t[t[u].l].siz + t[t[u].r].siz + t[u].cnt;
        return u;
    }
    else
    {
        t[v].l = merge(u, t[v].l);
        t[v].siz = t[t[v].l].siz + t[t[v].r].siz + t[v].cnt;
        return v;
    }
}

int insert(int root, int key)
{
    pair<int, int> sp1 = split(root, key - 1);
    pair<int, int> sp2 = split(sp1.second, key);
    if (sp2.first == 0)
    {
        sp2.first = createNode(key);
    }

    t[sp2.first].siz++;
    t[sp2.first].cnt++;

    return merge(merge(sp1.first, sp2.first), sp2.second);
}

int del(int root, int key)
{
    pair<int, int> sp1 = split(root, key - 1);
    pair<int, int> sp2 = split(sp1.second, key);
    if (sp2.first != 0)
    {
        if (t[sp2.first].cnt == 1)
        {
            sp2.first = 0;
        }
        else
        {
            t[sp2.first].cnt--;
            t[sp2.first].siz--;
        }
    }

    return merge(merge(sp1.first, sp2.first), sp2.second);
}

int count(int root, int key, int &ans)
{
    pair<int, int> sp = split(root, key - 1);
    ans = t[sp.first].siz + 1;
    return merge(sp.first, sp.second);
}

void update(int root)
{
    if (root == 0)
        return;
    update(t[root].l);
    update(t[root].r);

    t[root].siz = t[t[root].l].siz + t[t[root].r].siz + t[root].cnt;
}

int query(int root, int x)
{
    if (t[t[root].l].siz < x && t[t[root].l].siz + t[root].cnt >= x)
    {
        return t[root].key;
    }
    if (t[t[root].l].siz >= x)
        return query(t[root].l, x);
    return query(t[root].r, x - t[t[root].l].siz - t[root].cnt);
}

int prv(int root, int key, int &ans)
{
    pair<int, int> sp = split(root, key - 1);
    int curr = sp.first;
    while (curr != 0 && t[curr].r != 0)
    {
        curr = t[curr].r;
    }
    ans = t[curr].key;
    return merge(sp.first, sp.second);
}

int nxt(int root, int key, int &ans)
{
    pair<int, int> sp = split(root, key);
    int curr = sp.second;
    while (curr != 0 && t[curr].l != 0)
    {
        curr = t[curr].l;
    }
    ans = t[curr].key;
    return merge(sp.first, sp.second);
}

int ans = 0;

void output(int root)
{
    if (root == 0)
        return;
    output(t[root].l);
    if (t[root].key == 815809)
        printf(" !%d! ", ans);
    printf("%d:%d ", t[root].key, t[root].cnt);
    ans += t[root].cnt;
    output(t[root].r);
}

int main()
{
    srand(time(NULL));
    srand(rand());
    srand(rand());
    srand(rand());

    int T;
    int root = 0;
    scanf("%d", &T);
    int ans = 0;
    while (T--)
    {
        int op;
        int x;
        scanf("%d %d", &op, &x);
        switch (op)
        {
        case 1:
            root = insert(root, x);
            break;
        case 2:
            root = del(root, x);
            break;
        case 3:
            root = count(root, x, ans);
            printf("%d\n", ans);
            break;
        case 4:
            ans = query(root, x);
            printf("%d\n", ans);
            break;
        case 5:
            root = prv(root, x, ans);
            printf("%d\n", ans);
            break;
        case 6:
            root = nxt(root, x, ans);
            printf("%d\n", ans);
            break;
        default:
            break;
        }
    }
    return 0;
}
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值