AVL树是排序二叉树的优化版,多了个调整操作,排序二叉树在某些情况下可能会变的跟链表差不多,比如连续插入多个非递减的数:
这就会使得各种操作非常费时间,几乎和链表一样。但是AVL树避免了这一点,保证了任意一个节点的左右儿子节点的高度差不会超过1,这就使得AVL树的复杂度很平衡(插入删除查找log(n))。
这里的调整操作是基于两个旋转操作来进行的,即左旋和右旋:
右旋:
左旋(实际上就是右旋的反转):
当出现某个节点的左右儿子的高度差距太大时,就要进行调整使左右平衡。当出现类似 这种左儿子高度比右儿子高度大2且左儿子的左儿子比左儿子的右儿子高度大1的时候这种暂且叫做LL,然后我们需要一个右旋操作来进行平衡,同理RR(其实是因为作图实在太丑了就不想做了.)就是左儿子高度比右儿子高度小2且左儿子的左儿子比左儿子的右儿子高度小1。然后就是LR跟RL(两种情况一样):
从图就可以看出来,这种情况需要进行先左旋然后右旋两个旋转操作,RL就是先右旋然后左旋。经过这样的平衡操作之后二叉树就会处于一个比较平衡的状态。
首先是节点的定义:
struct node {
node *left, *right;
int val; // 值
int h; // 节点高度
};
必要函数:
int height(node *rt) { // 获取某个节点的高度
if (!rt) return 0;
return rt->h;
}
node *getNewNode(int val) { // 新建节点
node *rt = new node;
rt->left = rt->right = NULL;
rt->h = 1;
rt->val = val;
return rt;
}
node *singleLeftRotate(node *a) { // 左旋操作
node *b = a->right;
a->right = b->left;
b->left = a;
a->h = max(height(a->left), height(a->right)) + 1;
b->h = max(height(b->left), height(b->right)) + 1;
return b;
}
node *singleRightRotate(node *a) { // 右旋操作
node *b = a->left;
a->left = b->right;
b->right = a;
a->h = max(height(a->left), height(a->right)) + 1;
b->h = max(height(b->left), height(b->right)) + 1;
return b;
}
void adjust(node* &rt) { // 调整函数,经过的每个节点都要进行检查
rt->h = max(height(rt->left), height(rt->right)) + 1;
if (height(rt->left) - height(rt->right) == 2) {
if (height(rt->left->left) - height(rt->left->right) == 1) { // LL
rt = singleRightRotate(rt);
} else { // LR
rt->left = singleLeftRotate(rt->left);
rt = singleRightRotate(rt);
}
} else if (height(rt->right) - height(rt->left) == 2) {
if (height(rt->right->right) - height(rt->right->left) == 1) { // RR
rt = singleLeftRotate(rt);
} else { // RL
rt->right = singleRightRotate(rt->right);
rt = singleLeftRotate(rt);
}
}
rt->h = max(height(rt->left), height(rt->right)) + 1;
}
以上是AVL树的核心函数,然后就是插入删除操作:
node *Insert(node* &rt, int val) {
if (!rt) rt = getNewNode(val);
else if (val > rt->val) rt->right = Insert(rt->right, val);
else if (val < rt->val) rt->left = Insert(rt->left, val);
else { }
adjust(rt);
return rt;
}
node *deleteNode(node *rt, int val) { // 删除操作稍微麻烦一点,原理就是一旦找到要删除的值,将这个数与比它小的数中最小的(实际上就是从这个节点的左节点往下,每次都走左节点直到不能走 最大一个同理)或者比它大的数中最大的互换位置,然后这个数就到了叶子节点上去了,然后就直接删除就行了
if (val > rt->val) rt->right = deleteNode(rt->right, val);
else if (val < rt->val) rt->left = deleteNode(rt->left, val);
else {
if (rt->left) {
node* dn = NULL;
for (dn = rt->left; NULL != dn->right; dn = dn->right);
rt->val = dn->val;
rt->left = deleteNode(rt->left, dn->val);
} else if (NULL != rt->right) {
node* dn = NULL;
for (dn = rt->right; NULL != dn->left; dn = dn->left);
rt->val = dn->val;
rt->right = deleteNode(rt->right, dn->val);
} else {
free(rt);
return NULL;
}
}
adjust(rt);
return rt;
}
我觉得代码非常好理解的。。
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <cstdlib>
using namespace std;
struct node {
node *left, *right;
int val;
int h;
};
int height(node *rt) {
if (!rt) return 0;
return rt->h;
}
node *getNewNode(int val) {
node *rt = new node;
rt->left = rt->right = NULL;
rt->h = 1;
rt->val = val;
return rt;
}
node *singleLeftRotate(node *a) {
node *b = a->right;
a->right = b->left;
b->left = a;
a->h = max(height(a->left), height(a->right)) + 1;
b->h = max(height(b->left), height(b->right)) + 1;
return b;
}
node *singleRightRotate(node *a) {
node *b = a->left;
a->left = b->right;
b->right = a;
a->h = max(height(a->left), height(a->right)) + 1;
b->h = max(height(b->left), height(b->right)) + 1;
return b;
}
struct node *root;
void adjust(node* &rt) {
rt->h = max(height(rt->left), height(rt->right)) + 1;
if (height(rt->left) - height(rt->right) == 2) {
if (height(rt->left->left) - height(rt->left->right) == 1) {
rt = singleRightRotate(rt);
} else {
rt->left = singleLeftRotate(rt->left);
rt = singleRightRotate(rt);
}
} else if (height(rt->right) - height(rt->left) == 2) {
if (height(rt->right->right) - height(rt->right->left) == 1) {
rt = singleLeftRotate(rt);
} else {
rt->right = singleRightRotate(rt->right);
rt = singleLeftRotate(rt);
}
}
rt->h = max(height(rt->left), height(rt->right)) + 1;
}
node *Insert(node* &rt, int val) {
if (!rt) rt = getNewNode(val);
else if (val > rt->val) rt->right = Insert(rt->right, val);
else if (val < rt->val) rt->left = Insert(rt->left, val);
else { }
adjust(rt);
return rt;
}
node *deleteNode(node *rt, int val) {
if (val > rt->val) rt->right = deleteNode(rt->right, val);
else if (val < rt->val) rt->left = deleteNode(rt->left, val);
else {
if (rt->left) {
node* dn = NULL;
for (dn = rt->left; NULL != dn->right; dn = dn->right);
rt->val = dn->val;
rt->left = deleteNode(rt->left, dn->val);
} else if (NULL != rt->right) {
node* dn = NULL;
for (dn = rt->right; NULL != dn->left; dn = dn->left);
rt->val = dn->val;
rt->right = deleteNode(rt->right, dn->val);
} else {
free(rt);
return NULL;
}
}
adjust(rt);
return rt;
}
void dfs(node *rt) {
if (rt->left) dfs(rt->left);
printf("%d ", rt->val);
if (rt->right) dfs(rt->right);
}
int m, n, val;
int main() {
// freopen("out.txt", "r", stdin);
scanf("%d", &n);
for (int i = 0; i < n; ++i) {
scanf("%d", &val);
if (i == 0) root = getNewNode(val);
else root = Insert(root, val);
}
dfs(root);
printf("\n");
scanf("%d", &m);
for (int i = 0; i < m; ++i) {
scanf("%d", &val);
root = deleteNode(root, val);
}
dfs(root);
printf("\n");
return 0;
}