Code
#include <bits/stdc++.h>
using namespace std;
/// A node of a binary search tree. A Node owns nothing by itself: the AVLTree
/// that links it in is responsible for deleting it.
struct Node
{
    int val;
    Node *left, *right;
    Node(int v):val(v),left(nullptr),right(nullptr){}
};

/// A self-balancing binary search tree (AVL tree) of distinct ints.
///
/// Inserting a key that is already present leaves the tree unchanged.
/// The tree owns every Node reachable from `root` and frees them in its
/// destructor. Copying an AVLTree makes a deep copy; moving transfers the nodes.
///
/// Note: get_height() recomputes heights by walking the subtree, so each
/// insert/delete costs O(n) rather than the textbook O(log n). Caching the
/// height inside Node would remove that cost.
struct AVLTree
{
    Node *root;

    AVLTree():root(nullptr){}

    /// Deep-copies `other`, so the two trees never share (and double-free) nodes.
    AVLTree(const AVLTree &other):root(clone(other.root)){}

    /// Takes over `other`'s nodes, leaving `other` empty.
    AVLTree(AVLTree &&other) noexcept:root(other.root)
    {
        other.root = nullptr;
    }

    /// Copy- and move-assignment via copy-and-swap. `other` is already a copy
    /// (or a moved-from value); its destructor frees our previous nodes.
    AVLTree &operator=(AVLTree other) noexcept
    {
        std::swap(root, other.root);
        return *this;
    }

    ~AVLTree()
    {
        destroy(root);
    }

    /// Height of the subtree rooted at `node`: -1 for an empty subtree, 0 for
    /// a leaf. Visits every node of the subtree.
    int get_height(Node *node)
    {
        if(node == nullptr)
            return -1;
        return std::max(get_height(node->left), get_height(node->right)) + 1;
    }

    /// Rotates the subtree rooted at `x` to the left. Returns the new subtree
    /// root, which is x's former right child. Requires x->right != nullptr.
    Node* left_rotation(Node *x)
    {
        Node *y = x->right;
        x->right = y->left;
        y->left = x;
        return y;
    }

    /// Rotates the subtree rooted at `x` to the right. Returns the new subtree
    /// root, which is x's former left child. Requires x->left != nullptr.
    Node* right_rotation(Node *x)
    {
        Node *y = x->left;
        x->left = y->right;
        y->right = x;
        return y;
    }

    /// Balance factor of `node`: height(left) - height(right); 0 for nullptr.
    /// Positive means left-heavy, negative means right-heavy.
    int get_balance(Node *node)
    {
        if(node == nullptr)
            return 0;
        int hleft = get_height(node->left);
        int hright = get_height(node->right);
        return hleft - hright;
    }

    /// Inserts `v` into the subtree rooted at `node` and returns the (possibly
    /// new) subtree root after rebalancing. A duplicate key leaves the subtree
    /// unchanged.
    Node* insert(Node* node, const int &v)
    {
        if(node == nullptr)
            return new Node(v);  // a fresh leaf is trivially balanced
        if(v < node->val)
            node->left = insert(node->left, v);
        else if(v > node->val)
            node->right = insert(node->right, v);
        else
            return node;  // duplicate key: nothing to do

        int balance = get_balance(node);
        if(balance > 1)
        {
            // Left-heavy: `v` went into node->left's subtree. The side of
            // node->left it landed on decides single vs double rotation.
            if(v < node->left->val)                   // left-left
                return right_rotation(node);
            node->left = left_rotation(node->left);   // left-right
            return right_rotation(node);
        }
        if(balance < -1)
        {
            // Mirror image of the left-heavy case.
            if(v > node->right->val)                    // right-right
                return left_rotation(node);
            node->right = right_rotation(node->right);  // right-left
            return left_rotation(node);
        }
        return node;
    }

    /// Inserts `v` into the tree (no-op if already present).
    void insert(const int &v)
    {
        root = insert(root, v);
    }

    /// Leftmost (minimum-key) node of the subtree rooted at `node`, or
    /// nullptr for an empty subtree. Callers pass a right subtree to obtain
    /// an in-order successor.
    Node* successor(Node *node)
    {
        if(node == nullptr)
            return nullptr;
        while(node->left != nullptr)
            node = node->left;
        return node;
    }

    /// Removes `v` from the subtree rooted at `node` (no-op if absent) and
    /// returns the (possibly new) subtree root after rebalancing.
    Node* avl_delete(Node *node, const int &v)
    {
        if(node == nullptr)
            return node;  // key not present
        if(v < node->val)
            node->left = avl_delete(node->left, v);
        else if(v > node->val)
            node->right = avl_delete(node->right, v);
        else if(node->left != nullptr && node->right != nullptr)
        {
            // Two children: take over the in-order successor's key, then
            // remove the successor from the right subtree. The key is copied
            // *before* the recursive delete, which frees the successor node.
            const int succ_val = successor(node->right)->val;
            node->val = succ_val;
            node->right = avl_delete(node->right, succ_val);
        }
        else
        {
            // Zero or one child: splice this node out, replacing it with its
            // only child (or nullptr for a leaf).
            Node *doomed = node;
            node = (node->left != nullptr) ? node->left : node->right;
            delete doomed;
        }

        if(node == nullptr)
            return node;

        // Unlike insert, the deleted key says nothing about where the excess
        // height is, so inspect the child's balance factor instead.
        int balance = get_balance(node);
        if(balance > 1)
        {
            if(get_balance(node->left) >= 0)
                return right_rotation(node);
            node->left = left_rotation(node->left);
            return right_rotation(node);
        }
        if(balance < -1)
        {
            if(get_balance(node->right) <= 0)
                return left_rotation(node);
            node->right = right_rotation(node->right);
            return left_rotation(node);
        }
        return node;
    }

    /// Removes `v` from the tree (no-op if absent).
    void avl_delete(const int &v)
    {
        root = avl_delete(root, v);
    }

    /// Prints the keys of the subtree in ascending order, space-separated.
    void in_order(Node* node)
    {
        if(node == nullptr)
            return;
        in_order(node->left);
        std::cout << node->val << " ";
        in_order(node->right);
    }

    /// Prints the whole tree in order on one labelled line.
    void in_order()
    {
        std::cout << "in_order: ";
        in_order(root);
        std::cout << std::endl;
    }

    /// Prints the keys of the subtree in pre-order (node, left, right).
    void pre_order(Node* node)
    {
        if(node == nullptr)
            return;
        std::cout << node->val << " ";
        pre_order(node->left);
        pre_order(node->right);
    }

    /// Prints the whole tree in pre-order on one labelled line.
    void pre_order()
    {
        std::cout << "pr_order: ";
        pre_order(root);
        std::cout << std::endl;
    }

private:
    /// Frees `node` and its entire subtree (children first).
    static void destroy(Node *node)
    {
        if(node == nullptr)
            return;
        destroy(node->left);
        destroy(node->right);
        delete node;
    }

    /// Returns a freshly allocated deep copy of the subtree rooted at `node`.
    static Node* clone(const Node *node)
    {
        if(node == nullptr)
            return nullptr;
        Node *copy = new Node(node->val);
        try
        {
            copy->left = clone(node->left);
            copy->right = clone(node->right);
        }
        catch(...)
        {
            destroy(copy);  // don't leak the partial copy if an allocation throws
            throw;
        }
        return copy;
    }
};
int main()
{
AVLTree T;
vector<int> vec{1,2,3,4,5,6,7,8,9};
cout << "insert: " << endl;
for(int i=0;i<int(vec.size());i++)
{
T.insert(vec[i]);
T.in_order();
T.pre_order();
}
cout << "delete: " << endl;
for(int i=vec.size()-1;i>=0;i--)
{
T.avl_delete(vec[i]);
T.in_order();
T.pre_order();
}
return 0;
}
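Test results (output of main). Note: this run only exercises single rotations during insert and never deletes a node with two children, so it does not cover the left-right/right-left insert cases or the successor path of avl_delete.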
insert:
in_order: 1
pr_order: 1
in_order: 1 2
pr_order: 1 2
in_order: 1 2 3
pr_order: 2 1 3
in_order: 1 2 3 4
pr_order: 2 1 3 4
in_order: 1 2 3 4 5
pr_order: 2 1 4 3 5
in_order: 1 2 3 4 5 6
pr_order: 4 2 1 3 5 6
in_order: 1 2 3 4 5 6 7
pr_order: 4 2 1 3 6 5 7
in_order: 1 2 3 4 5 6 7 8
pr_order: 4 2 1 3 6 5 7 8
in_order: 1 2 3 4 5 6 7 8 9
pr_order: 4 2 1 3 6 5 8 7 9
delete:
in_order: 1 2 3 4 5 6 7 8
pr_order: 4 2 1 3 6 5 8 7
in_order: 1 2 3 4 5 6 7
pr_order: 4 2 1 3 6 5 7
in_order: 1 2 3 4 5 6
pr_order: 4 2 1 3 6 5
in_order: 1 2 3 4 5
pr_order: 4 2 1 3 5
in_order: 1 2 3 4
pr_order: 2 1 4 3
in_order: 1 2 3
pr_order: 2 1 3
in_order: 1 2
pr_order: 2 1
in_order: 1
pr_order: 1
in_order:
pr_order: