平衡树之短小精悍的 AVL 树

之所以有这篇博客是因为网上的 AVL 树代码普遍很长,讲解也很复杂,实际上这是一个很显然且容易编写的数据结构。本篇博客使用数组实现 AVL 树,指针同理,可以用作 AVL 树入门。

思想

平衡树是二叉搜索树(BST)的改进,二叉搜索树每个点维护 v a l u val_u valu 表示结点 u u u 的权值, l s x ls_x lsx r s x rs_x rsx 分别表示结点 u u u 的左儿子和右儿子,满足 v a l l s x < v a l x < v a l r s x val_{ls_x}<val_x<val_{rs_x} vallsx<valx<valrsx。即左儿子权值小于父亲权值小于右儿子权值。

BST 在极端数据下会出现失衡的状况,于是有了平衡树。

AVL 树是平衡树的一种,令重量 s z x sz_x szx 表示以 x x x 为根的子树大小,则 ∣ s z l s x − s z r s x ∣ ≤ 1 |sz_{ls_x}-sz_{rs_x}|\le1 szlsxszrsx1,即左右儿子重量之差的绝对值不超过 1 1 1

AVL 树通过旋转来维护树的平衡。

旋转

AVL 树的旋转只有一种,就是单旋,对于结点 x x x 和其父亲 f a fa fa,将结点 x x x 向上旋转就是让 x x x 作为 f a fa fa 的父亲,原来 x x x 连接的一个儿子 v v v 作为 f a fa fa 的儿子。

初学建议画图。

另外,可能有重复的值,用 c n t x cnt_x cntx 表示 v a l x val_x valx 的个数, s z x sz_x szx 表示子树 x 值的个数, n s z x nsz_x nszx 表示子树 x x x 结点个数, c h x , 0 / 1 ch_{x,0/1} chx,0/1 表示结点 x 的左 / 右儿子。

// 维护 sz 和 nsz
void pushup(int x) { sz[x] = sz[ls] + sz[rs] + cnt[x], nsz[x] = nsz[ls] + nsz[rs] + 1; }

 void spin(int x, int & fa) {	// 旋转操作
    int d = ch[fa][1] == x;	// d 表示 x 是 fa 的左/右儿子
    ch[fa][d] = ch[x][d ^ 1], ch[x][d ^ 1] = fa;
    pushup(fa), pushup(x), fa = x;
}

维持平衡

当不满足 ∣ n s z l s x − n s z r s x ∣ ≤ 1 |nsz_{ls_x}-nsz_{rs_x}|\le1 nszlsxnszrsx1 时,则将其进行旋转。

正常来说,应该先将其转成都往一边偏的形状,再将结点 x x x 重的儿子进行旋转:

      x				    x              d
     / \               / \            / \
    a   b             d   b          a   x
   / \		-->      / \     -->    / \ / \
  c   d             a   f          c  e f  b
     / \           / \
    e   f         c   e

代码理应是:

void maintain(int & x) {	// 维持平衡
    int d = abs(nsz[ls] - nsz[rs]) > 1 ? nsz[ls] < nsz[rs] : -1;
    if(d == -1) return;
    int & c = ch[x][d];
    if(nsz[ch[c][d]] < nsz[ch[c][d ^ 1]]) spin(ch[c][d ^ 1], c);
    spin(c, x);
}

但实际测试和直接将结点 x x x 的重的儿子进行旋转效率差不多,甚至后者还更快(这里当初写错了几次,结果改的复杂度越正确跑得越慢):

void maintain(int & x) {	// 维持平衡
    if(nsz[ls] > nsz[rs] + 1) spin(ls, x);
    if(nsz[rs] > nsz[ls] + 1) spin(rs, x);
}

因此采用下面这种。

插入

现在要插入值 v v v,对于结点 x x x,若 v < v a l x v<val_x v<valx,则向左递归;若 v > v a l x v>val_x v>valx,则向右递归,若 v = v a l x v=val_x v=valx,则令 c n t x ← c n t x + 1 cnt_x\leftarrow cnt_x+1 cntxcntx+1 并返回。当递归到叶子结点时将其作为叶子节点的儿子。

// 添加结点,cl 用来回收删除的结点(不持久化的话其实没必要)
void add_node(int & x, int v) { val[x = cln ? cl[cln--] : ++tot] = v, ls = rs = 0, sz[x] = nsz[x] = 1, cnt[x] = 1; }

void ins(int & x, int v) {	// 插入
    if(!x) return add_node(x, v), void();
    if(v == val[x]) return ++cnt[x], ++sz[x], void();
    v < val[x] ? ins(ls, v) : ins(rs, v);
    maintain(x), pushup(x);
}

删除

思路是找到了要删除的结点 x x x 后,若 c n t x > 1 cnt_x>1 cntx>1,则将其减一,否则将 x x x 旋转成叶子结点并删除。

若旋转到某一时刻使得 x x x 没有左儿子或者没有右儿子,可以直接删除结点 x x x 并返回。

尽量旋转 n s z nsz nsz 较大的儿子,这样能保证树不会失衡。

void del(int & x, int v) {	// 删除
    if(!x) return;
    if(v == val[x]) {	// 找到了结点 x
        if(cnt[x] > 1) return --cnt[x], --sz[x], void();	// cnt[x]>1 就直接将其减 1
        if(!ls || !rs) return cl[++cln] = x, x = ls | rs, void();	// 没有左儿子或右儿子就直接删除结点 x
        spin(sz[ls] > sz[rs] ? ls : rs, x), del(x, v);	// 否则将其旋转
    } else del(v < val[x] ? ls : rs, v);
    pushup(x);
}

获取排名

现在要获取值 v v v 的排名(比 v v v 小的数的个数 + 1 +1 +1),可以在树上二分,具体见代码。

int grk(int x, int v) {		// 获取比 v 小的数的个数
    if(!x) return 0;
    if(v <= val[x]) return grk(ls, v);	// 若 v<=val[x],则递归左子树
    return grk(rs, v) + sz[ls] + cnt[x];	// 否则,答案为递归右子树得到的答案,加上左子树大小和 cnt[x]
}

第 k 大 / 小

以第 k 小为例,同样在树上二分。

int kth(int x, int rk) {	// 获取第 rk 小
    if(ls && rk <= sz[ls]) return kth(ls, rk);
    if(rk <= sz[ls] + cnt[x]) return val[x];
    return kth(rs, rk - sz[ls] - cnt[x]);
}

求前驱 / 后继

一个数的前驱指的是比这个数小且最大的数,后继指的是比这个数大且最小的数。

怎么求呢?有个简单方法,比如现在 AVL 树维护的是 { 1 , 2 , 2 , 2 , 3 , 3 , 3 , 4 , 5 } \{1,2,2,2,3,3,3,4,5\} {1,2,2,2,3,3,3,4,5},如果要求 3 3 3 的前驱,那么先求出比 3 3 3 小的数的个数,即 { 1 , 2 , 2 , 2 } \{1,2,2,2\} {1,2,2,2} 4 个,然后求第 4 4 4 小,就求出了 2 2 2,即 3 3 3 的前驱。求后继同理。

这样写经实际测试常数并不是很大,可以接受。

int pre(int v) { return kth(rt, grk(rt, v)); }	// 求前驱

int nxt(int v) { return kth(rt, grk(rt, v + 1) + 1); }	// 求后继

完整代码

const int N = 1100010;

namespace AVL {
    int val[N], ch[N][2], sz[N], nsz[N], tot, cnt[N], rt;
    int cl[N], cln;

    void pushup(int x) { sz[x] = sz[ls] + sz[rs] + cnt[x], nsz[x] = nsz[ls] + nsz[rs] + 1; }

    void spin(int x, int & fa) {
        int d = ch[fa][1] == x;
        ch[fa][d] = ch[x][d ^ 1], ch[x][d ^ 1] = fa;
        pushup(fa), pushup(x), fa = x;
    }

    void maintain(int & x) {
        if(nsz[ls] > nsz[rs] + 1) spin(ls, x);
        if(nsz[rs] > nsz[ls] + 1) spin(rs, x);
    }

    void add_node(int & x, int v) { val[x = cln ? cl[cln--] : ++tot] = v, ls = rs = 0, sz[x] = nsz[x] = 1, cnt[x] = 1; }

    void ins(int & x, int v) {
        if(!x) return add_node(x, v), void();
        if(v == val[x]) return ++cnt[x], ++sz[x], void();
        v < val[x] ? ins(ls, v) : ins(rs, v);
        maintain(x), pushup(x);
    }

    void del(int & x, int v) {
        if(!x) return;
        if(v == val[x]) {
            if(cnt[x] > 1) return --cnt[x], --sz[x], void();
            if(!ls || !rs) return cl[++cln] = x, x = ls | rs, void();
            spin(sz[ls] > sz[rs] ? ls : rs, x), del(x, v);
        } else del(v < val[x] ? ls : rs, v);
        pushup(x);
    }

    int grk(int x, int v) {
        if(!x) return 0;
        if(v <= val[x]) return grk(ls, v);
        return grk(rs, v) + sz[ls] + cnt[x];
    }

    int kth(int x, int rk) {
        if(ls && rk <= sz[ls]) return kth(ls, rk);
        if(rk <= sz[ls] + cnt[x]) return val[x];
        return kth(rs, rk - sz[ls] - cnt[x]);
    }

    int pre(int v) { return kth(rt, grk(rt, v)); }

    int nxt(int v) { return kth(rt, grk(rt, v + 1) + 1); }
}

共 54 行且没有刻意压行,应该算是比较简短的,相比 Splay,替罪羊树,Treap,FHQ Treap,WBLT 之类的好写多了,效率很可观。

  • 24
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值