AVL 平衡树的实现

看了《数据结构与算法分析》之后,发现 AVL 平衡树相对于红黑树更加简单明了,代码如下:

#ifndef AVLTREE_H
#define AVLTREE_H

#include <cstdio>
#include <vector>

/**
 * One node of an AVL tree, holding a key/value pair plus the links and
 * bookkeeping the tree needs.
 *
 * All members are public and are manipulated directly by the tree.
 * A freshly constructed node has no children, no parent, height 1 and layer 0.
 */
template<typename K, typename V>
class AVLNode{
public:

    /// Creates a detached node with value-initialized key and value
    /// (zero for arithmetic types) so they are never read uninitialized.
    AVLNode()
        : left(nullptr), right(nullptr), parent(nullptr),
          key(), value(), height(1), layer(0)
    {
    }

    /// Creates a detached node holding a copy of k and v.
    /// The members are constructed directly from the arguments, so K and V
    /// only need to be copy-constructible, not default-constructible.
    AVLNode(K k, V v)
        : left(nullptr), right(nullptr), parent(nullptr),
          key(k), value(v), height(1), layer(0)
    {
    }

    AVLNode *left;    ///< Left child, or nullptr.
    AVLNode *right;   ///< Right child, or nullptr.
    AVLNode *parent;  ///< Parent node, or nullptr.
    K key;            ///< Key stored in this node.
    V value;          ///< Value stored in this node.
    int height;       ///< Subtree height; a node without children has height 1.
    int layer;        ///< Depth label, initially 0.

};
template<typename K, typename V>
class AVLTree
{
public:
    AVLTree(int compare(K a, K b),int dump(int height, K &a, V &v))
    {
        root = nullptr;
        this->compare = compare;
        this->dump_t = dump;
    }

    ~AVLTree(){
        deleteNode(root);
    }

    void deleteNode(AVLNode<K,V> *node){
        if(node == nullptr){
            return;
        }
        deleteNode(node->left);
        deleteNode(node->right);
        delete node;
    }

    

    void insert(K key, V value){
          root = insertNode(root, key, value);
    }

    int max(int a, int b){
        if(a > b){
            return  a;
        }else{
            return b;
        }
    }

    int height(AVLNode<K,V> *node){
        if(node == nullptr){
            return 0;
        }
        return node->height;
    }

    int updateHeight(AVLNode<K,V> *node)
    {
         int h = max(height(node->left),height(node->right)) + 1;
         return h;
    }

    AVLNode<K,V> *signalRotateLeftChild(AVLNode<K,V> *k2){
        AVLNode<K,V> *k1 = k2->left;
        k2->left = k1->right;
        if(k2->left){
            k2->left->parent = k2;
        }
        k1->right = k2;


        k1->parent = k2->parent;
        k2->parent = k1;


        k2->height = updateHeight(k2);
        k1->height = updateHeight(k1);

        return k1;
    }

    AVLNode<K,V> *doubleRotateLeftChild(AVLNode<K,V> *k2){
        k2->left = signalRotateRightChild(k2->left);
        if(k2->left){
            k2->left->parent = k2;
        }
        return signalRotateLeftChild(k2);
    }


    AVLNode<K,V> *signalRotateRightChild(AVLNode<K,V> *k2){
        AVLNode<K,V> *k1 = k2->right;
        k2->right = k1->left;
        if(k2->right){
            k2->right->parent = k2;
        }
        k1->left = k2;

        k1->parent = k2->parent;
        k2->parent = k1;

        k2->height = updateHeight(k2);
        k1->height = updateHeight(k1);

        return k1;
    }

    AVLNode<K,V> *doubleRotateRightChild(AVLNode<K,V> *k2){
        k2->right = signalRotateLeftChild(k2->right);
        if(k2->right){
            k2->right->parent = k2;
        }
        return signalRotateRightChild(k2);
    }

    AVLNode<K,V> *insertNode(AVLNode<K,V> *node, K key, V value){

        if(node == nullptr){
            return new AVLNode<K,V>(key,value);
        }
        int result  = compare(key, node->key);
        if(result < 0){
            node->left = insertNode(node->left, key, value);
            if(node->left){
                node->left->parent = node;
            }
        }else if(result > 0){
            node->right = insertNode(node->right, key, value);
            if(node->right){
                node->right->parent = node;
            }
        }else{
            return  node;
        }
        int leftHeight = height(node->left);
        int rightHeight = height(node->right);
        if(leftHeight - rightHeight == 2){
            if(compare(key, node->left->key) < 0){
                node = signalRotateLeftChild(node);
            }else{
                node = doubleRotateLeftChild(node);
            }
        }else  if(rightHeight - leftHeight  == 2){
            if(compare(key, node->right->key) > 0){
                node = signalRotateRightChild(node);
            }else{
                node = doubleRotateRightChild(node);
            }
        }
        node->height = updateHeight(node);
        return node;
    }
    AVLNode<K,V> *findMin(AVLNode<K,V> *node){
        if(node == nullptr) {
            return node;
        }
        if(node->left == nullptr) {
            return node;
        }
        return findMin(node->left);
    }

    void remove(K key){
        root = remove(root,key);
    }

     AVLNode<K,V>  *remove(AVLNode<K,V> *node ,K key){
         if(node == nullptr){
             return node;
         }
       int result = compare(key,node->key);
       if(result == 0) {
           // 右枝还有叶子,找到右枝最小叶子,代替要删除的节点,
           // 否则直接删除此节点
            if(node->right){
               AVLNode<K,V> *rightMin = findMin(node->right);
               if(rightMin == nullptr) {
                   if(node->parent){
                       if(node->parent->left == node){
                           node->parent->left =  node->right;
                       }else{
                           node->parent->right =  node->right;
                       }
                   }
                   AVLNode<K,V> *temp;
                   temp = node->right;
                   if(temp){
                        temp->parent = node->parent;
                   }
                   delete node;
                   node = temp;
               }else {
                   node->key = rightMin->key;
                   node->right = remove(node->right, node->key);
                   if(node->right){
                       node->right->parent = node;
                   }
               }
            }else{
                AVLNode<K,V> *temp;
                temp = node->left;
                if(temp){
                     temp->parent = node->parent;
                }
                delete node;
                node = temp;
            }
       }else if(result < 0) {
           node->left = remove(node->left, key);
           if(node->left){
               node->left->parent = node;
           }
       }else {
           node->right = remove(node->right, key);
           if(node->right){
               node->right->parent = node;
           }
       }

       if(node == nullptr){
           return node;
       }

       if(height(node->left) - height(node->right) == 2) {
           int k1h = height(node->left->left);
           int k2h = height(node->left->right);

           if(k1h > k2h) {
               node = signalRotateLeftChild(node);
           }else {
               node = doubleRotateLeftChild(node);
           }
       }else if(height(node->right) - height(node->left) == 2) {
           int k1h = height(node->right->left);
           int k2h = height(node->right->right);
           if(k2h  > k1h) {
               node = signalRotateRightChild(node);
           }else {
               node = doubleRotateRightChild(node);
           }
       }
       node->height = max( height(node->right) , height(node->left))+1;
       return node;
   }


 AVLNode<K,V> *findMax(AVLNode<K,V> *node){
        if(node == nullptr) {
            return node;
        }
        if(node->right == nullptr) {
            return node;
        }
        return findMax(node->right);
    }

  //这个删除实现可能更好,更加容易理解
 void remove2(K key){
        root = remove2(root,key);
    }

     AVLNode<K,V>  *remove2(AVLNode<K,V> *node ,K key){
         if(node == nullptr){
             return node;
         }
       int result = compare(key,node->key);
       if(result == 0) {
           if(node->right == nullptr && node->left == nullptr){
               delete  node;
               node = nullptr;
           }else if(height(node->right) > height(node->left)){
               AVLNode<K,V> *find = findMin(node->right);
               node->key = find->key;
               node->right = remove2(node->right, node->key);
               if(node->right){
                   node->right->parent = node;
               }
           }else{
               AVLNode<K,V> *find = findMax(node->left);
               node->key = find->key;
               node->left = remove2(node->left, node->key);
               if(node->left){
                   node->left->parent = node;
               }
           }

       }else if(result < 0) {
           node->left = remove2(node->left, key);
           if(node->left){
               node->left->parent = node;
           }
       }else {
           node->right = remove2(node->right, key);
           if(node->right){
               node->right->parent = node;
           }
       }

       if(node == nullptr){
           return node;
       }

       if(height(node->left) - height(node->right) == 2) {
           int k1h = height(node->left->left);
           int k2h = height(node->left->right);

           if(k1h > k2h) {
               node = signalRotateLeftChild(node);
           }else {
               node = doubleRotateLeftChild(node);
           }
       }else if(height(node->right) - height(node->left) == 2) {
           int k1h = height(node->right->left);
           int k2h = height(node->right->right);
           if(k2h  > k1h) {
               node = signalRotateRightChild(node);
           }else {
               node = doubleRotateRightChild(node);
           }
       }
       node->height = max( height(node->right) , height(node->left))+1;
       return node;
   }


void updateLayer()
    {
        updateLayerNode(root,0);
    }

    void updateLayerNode(AVLNode<K,V> *node, int layer){
        if(node == nullptr){
            return;
        }
        node->layer = layer;
        updateLayerNode(node->left,layer+1);
        updateLayerNode(node->right,layer+1);
    }


     void dumpSpace(int h, int layer)
     {
         int count = 1;
         for(int i = 0;i < h+1 - layer;i++){
             count *= 2;
         }
         count = count-4;
         for(int i = 0;i < count;i++){
             printf(" ");
         }
     }

     void dumpSpace2(int h, int layer)
     {
         dumpSpace(h,layer);
         printf("    ");

     }

     void dump3(){
         if(root == nullptr){
             return;
         }
         updateLayer();
         std::vector<AVLNode<K,V> *>list;
         int count = 1;
         list.push_back(root);
         int layer = 1;
         int layerNum = 0;
         int layerCount = 1;

         while(list.size() > 0 && count > 0){
             AVLNode<K,V> *node = list.at(0);
             list.erase(list.begin());
             if(node){
                 count--;
                 if(dump_t){
                     dumpSpace(root->height, layerNum);
                     dump_t(node->height,node->key,node->value);
                     dumpSpace2(root->height, layerNum);
                 }
                 if(node->left){
                     count++;
                     list.push_back(node->left);
                 }else{
                     list.push_back(nullptr);
                 }
                 if( node->right){
                     count++;
                     list.push_back(node->right);
                 }else{
                     list.push_back(nullptr);
                 }
             }else{
                 list.push_back(nullptr);
                 list.push_back(nullptr);

                 dumpSpace(root->height, layerNum);
                 printf("___,");
                 dumpSpace2(root->height, layerNum);
             }
             layerCount--;
             if(layerCount == 0){
                 layerNum++;
                 layer *= 2;
                 layerCount = layer;
                 printf("\n");
             }
         }
         printf("\n");
     }

     void dump(){
         dumpNode(root, 0);
     }

     void dumpNode(AVLNode<K,V> *node, int layer){

         if(!node){
             return;
         }
         int n = node->height;
         for(int i = 0;i <= n;i++){
            printf(" ");
         }
         if(dump_t){
             dump_t(node->height, node->key, node->value);
         }
         dumpNode(node->left,layer+1);
         dumpNode(node->right,layer+1);
         printf("\n");
     }

     void dump2(){
         dumpNode2(root, 0);
     }

     void dumpNode2(AVLNode<K,V> *node, int layer){

         if(!node){
             return;
         }

         dumpNode2(node->left,layer+1);
         if(dump_t){
             dump_t(node->height, node->key, node->value);
         }
         dumpNode2(node->right,layer+1);
     }

private:
    AVLNode<K,V> *root;
    int (*compare)(K a, K b);
    int (*dump_t)(int height, K &a, V &v);
};

#endif // AVLTREE_H

简单测试代码:

#include "avltree.h"

// Three-way comparison for int keys: returns -1 if a < b, 0 if a == b and
// 1 if a > b. Written with comparisons rather than `a - b`, because the
// subtraction overflows (undefined behaviour) for operands that are far apart.
int cmp(int a, int b){
    return (a > b) - (a < b);
}
// Node printer for the test: writes the key as a zero-padded, three-digit
// number followed by a comma. The height and the value are not used.
// Always returns 0.
int dump(int h, int &key, int &value)
{
    (void)h;
    (void)value;
    printf("%03d,", key);
    return 0;
}

// Smoke test: inserts the keys 0..19 one at a time, then removes them in
// reverse order, printing the tree level by level after every step.
void testAVL()
{
    // Automatic storage: the tree (and all of its nodes) is released when
    // the function returns. The previous `new` without `delete` leaked it.
    AVLTree<int, int> tree(cmp, dump);
    for(int i = 0; i < 20; i++){
        printf("\n-- insert: %d\n", i);
        tree.insert(i, i);
        tree.dump3();
    }
    tree.dump3();
    for(int i = 19; i >= 0; i--){
        printf("\n-- remove: %d\n", i);
        tree.remove(i);
        tree.dump3();
    }
    tree.dump3();
}

// Program entry point: runs the AVL tree smoke test and reports success.
int main()
{
    testAVL();
    return 0;
}

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
AVL平衡树是一种自平衡二叉查找树,主要解决二叉查找树在插入和删除操作时可能会失去平衡的问题。在AVL平衡树中,每个节点都有一个平衡因子,即左子树的高度减去右子树的高度,平衡因子的值只能为-1、0、1三种。当插入或删除一个节点后,如果导致某个节点的平衡因子的绝对值大于1,就需要通过旋转操作来重新平衡这个节点及其子树。 下面是一个非指针实现AVL平衡树的示例代码: ```c++ #include <iostream> using namespace std; const int MAXN = 100; class AVLTree { public: AVLTree() : size(0) {} bool insert(int val) { if (size == MAXN) return false; root = insert(root, val); size++; return true; } bool remove(int val) { if (size == 0) return false; root = remove(root, val); size--; return true; } void print() { inorder(root); } private: struct Node { int val, h; Node *l, *r; Node(int v, Node *L = NULL, Node *R = NULL) : val(v), h(0), l(L), r(R) {} int height() { return this ? h : -1; } int balance() { return r->height() - l->height(); } void update() { h = max(l->height(), r->height()) + 1; } Node *rotate_left() { Node *p = r; r = p->l; p->l = this; update(); p->update(); return p; } Node *rotate_right() { Node *p = l; l = p->r; p->r = this; update(); p->update(); return p; } Node *rotate_left_right() { r = r->rotate_right(); return rotate_left(); } Node *rotate_right_left() { l = l->rotate_left(); return rotate_right(); } }; Node *root; int size; Node *insert(Node *p, int val) { if (!p) return new Node(val); if (val < p->val) { p->l = insert(p->l, val); if (p->balance() == 2) { if (val < p->r->val) p = p->rotate_right_left(); else p = p->rotate_left(); } } else if (val > p->val) { p->r = insert(p->r, val); if (p->balance() == -2) { if (val > p->l->val) p = p->rotate_left_right(); else p = p->rotate_right(); } } p->update(); return p; } Node *remove(Node *p, int val) { if (!p) return NULL; if (val < p->val) { p->l = remove(p->l, val); if (p->balance() == -2) { if (p->l->balance() <= 0) p = p->rotate_right(); else p = p->rotate_left_right(); } } else if (val > p->val) { p->r = remove(p->r, val); if (p->balance() == 2) { if (p->r->balance() >= 0) p = p->rotate_left(); else p = p->rotate_right_left(); } } else { if (!p->l && !p->r) { delete p; return NULL; } else if (p->l) 
{ Node *q = p->l; while (q->r) q = q->r; p->val = q->val; p->l = remove(p->l, q->val); if (p->balance() == -2) { if (p->l->balance() <= 0) p = p->rotate_right(); else p = p->rotate_left_right(); } } else { Node *q = p->r; while (q->l) q = q->l; p->val = q->val; p->r = remove(p->r, q->val); if (p->balance() == 2) { if (p->r->balance() >= 0) p = p->rotate_left(); else p = p->rotate_right_left(); } } } p->update(); return p; } void inorder(Node *p) { if (!p) return; inorder(p->l); cout << p->val << " "; inorder(p->r); } }; ``` 这里实现了AVL平衡树的插入、删除和中序遍历操作。其中,AVL平衡树的旋转操作被封装在了节点结构体中,包括左旋、右旋、左右旋和右左旋四种情况。具体实现时,需要注意节点的高度、平衡因子的计算和更新,以及对树的递归遍历等细节问题。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值