思路:
分析总结插入后引起不平衡的几种情况:
1: 插入到外侧, 即失衡节点的左子节点的左子树(LL), 或右子节点的右子树(RR), 使用单旋转;
2: 插入到内侧, 即失衡节点的左子节点的右子树(LR), 或右子节点的左子树(RL), 使用双旋转;
3: 使用递归插入, 在回溯时更新沿途各节点的 height (用于判断平衡: 左右子树高度差不超过 1, 差值达到 2 时需要旋转);
具体情况见《数据结构与算法分析》P82.
#include "string.h"
#include "stdlib.h"
#include "stdio.h"
/* Key type stored in each tree node. */
typedef int element_type;

/* Node type plus two pointer aliases: avl_tree for (sub)tree roots,
 * position for nodes being moved around by the rotations. */
typedef struct avl_node avl_node_t;
typedef avl_node_t *avl_tree;
typedef avl_node_t *position;

struct avl_node {
    element_type element;  /* key */
    avl_tree left;         /* subtree of smaller keys */
    avl_tree right;        /* subtree of larger keys */
    int height;            /* cached height: 0 for a leaf */
};
/* Larger of two values. Each argument may be evaluated twice, so only pass side-effect-free expressions. */
#define MAX(x, y) (((y) < (x)) ? (x) : (y))
/*
 * Cached height of the subtree rooted at `p`.
 * An empty subtree counts as -1, so a single leaf has height 0.
 */
int height(position p)
{
    return (p != NULL) ? p->height : -1;
}
/*
 * Single rotation used by insert() for the left-left case: k2's left child
 * is lifted to become the subtree root and k2 becomes its right child.
 * k2 is now the lower node, so its height is recomputed first.
 * Returns the new subtree root.
 */
avl_tree single_rotate_with_left(position k2)
{
    position new_root = k2->left;

    k2->left = new_root->right;   /* new_root's right subtree is re-hung under k2 */
    new_root->right = k2;

    k2->height = 1 + MAX(height(k2->left), height(k2->right));
    new_root->height = 1 + MAX(height(new_root->left), k2->height);
    return new_root;
}
/*
 * Single rotation used by insert() for the right-right case: k1's right
 * child is lifted to become the subtree root and k1 becomes its left child.
 * k1 is now the lower node, so its height is recomputed first.
 * Returns the new subtree root.
 */
avl_tree single_rotate_with_right(position k1)
{
    position new_root = k1->right;

    k1->right = new_root->left;   /* new_root's left subtree is re-hung under k1 */
    new_root->left = k1;

    k1->height = 1 + MAX(height(k1->left), height(k1->right));
    new_root->height = 1 + MAX(k1->height, height(new_root->right));
    return new_root;
}
/*
 * Double rotation used by insert() for the left-right case (the new key went
 * into the right subtree of k3's left child).
 *
 * First rotate k3's left child with its right child. Then rotate k3 with its
 * new left child. The grandchild ends up as the subtree root.
 * Returns the new subtree root.
 */
avl_tree double_rotate_with_left(position k3)
{
    k3->left = single_rotate_with_right(k3->left);
    return single_rotate_with_left(k3);
}
/*
 * Double rotation used by insert() for the right-left case (the new key went
 * into the left subtree of k1's right child).
 *
 * First rotate k1's right child with its left child. Then rotate k1 with its
 * new right child. The grandchild ends up as the subtree root.
 * Returns the new subtree root.
 */
avl_tree double_rotate_with_right(position k1)
{
    k1->right = single_rotate_with_left(k1->right);
    return single_rotate_with_right(k1);
}
/*
 * Insert `element` into the AVL tree rooted at `tree` and return the
 * (possibly new) root of that subtree. Keys already present are ignored.
 *
 * After the recursive insert into one side, a height difference of 2 between
 * the children means this node has become unbalanced:
 *   - outside insertion (left-left / right-right) -> single rotation
 *   - inside insertion  (left-right / right-left) -> double rotation
 * The cached height is refreshed on the way back up the recursion.
 *
 * If allocating the new leaf fails, "malloc error" is reported on stderr and
 * NULL is returned for that empty slot, so the tree is left unchanged.
 */
avl_tree insert(element_type element, avl_tree tree)
{
    if (tree == NULL) {
        /* Reached an empty slot: create the new leaf. */
        tree = malloc(sizeof *tree);
        if (tree == NULL) {
            fprintf(stderr, "malloc error\n");
            return NULL;
        }
        tree->element = element;
        tree->left = tree->right = NULL;
        tree->height = 0;
        return tree;   /* a fresh leaf is trivially balanced */
    }

    if (element > tree->element) {
        tree->right = insert(element, tree->right);
        if (height(tree->right) - height(tree->left) == 2) {
            if (element > tree->right->element) {
                tree = single_rotate_with_right(tree);   /* right-right */
            } else {
                tree = double_rotate_with_right(tree);   /* right-left */
            }
        }
    } else if (element < tree->element) {
        tree->left = insert(element, tree->left);
        if (height(tree->left) - height(tree->right) == 2) {
            if (element < tree->left->element) {
                tree = single_rotate_with_left(tree);    /* left-left */
            } else {
                tree = double_rotate_with_left(tree);    /* left-right */
            }
        }
    }
    /* else: the key already exists; nothing to do. */

    /* Refresh the cached height on every return along the insertion path. */
    tree->height = MAX(height(tree->left), height(tree->right)) + 1;
    return tree;
}
/*
 * Number of levels in the tree: 0 for an empty tree, 1 for a lone node.
 * This walks the whole tree rather than reading the cached `height` field.
 */
int get_tree_depth(avl_tree root)
{
    if (root != NULL) {
        int depth_l = get_tree_depth(root->left);
        int depth_r = get_tree_depth(root->right);
        /* deepest child plus one for this node */
        return 1 + (depth_l > depth_r ? depth_l : depth_r);
    }
    return 0;
}
/* Layout state shared by print_tree() and print_core_tree(). */
static int print_pos_x;      /* next free column while laying out the tree */
static int print_tree_high;  /* depth of the tree currently being printed */

/* Columns reserved for each in-order slot (one node or one empty child). */
#define ELMENT_PRINT_SPACE (12)
/* Left padding that centres a label of length l in a slot of width m. */
#define ELMENT_PRINT_OFFSET(m,l) (((m) - (l)) / 2)
/* Number of tree levels (rows) print_buffer can hold. */
#define ELMENT_PRINT_MAX_DEPTH (6)
/* Columns taken by 2^l slots of width x (arguments parenthesized for safety). */
#define CALC_PRINT_ELMENT_TAKE_PLACE(l,x) ((1 << (l)) * (x))
/*
 * A tree with D levels has at most 2^D - 1 nodes and 2^D empty children.
 * print_core_tree() gives each of them its own slot after a half-slot
 * starting offset. That is 2^(D+1) - 1 slots, which fits in 2^(D+1) slots.
 * The old width of 2^D slots overflowed for trees of 32+ nodes.
 */
#define ELMENT_PRINT_MAX_DEPTH_SPACE CALC_PRINT_ELMENT_TAKE_PLACE(ELMENT_PRINT_MAX_DEPTH + 1, ELMENT_PRINT_SPACE)
/* Text canvas: one row per tree level. */
char print_buffer[ELMENT_PRINT_MAX_DEPTH][ELMENT_PRINT_MAX_DEPTH_SPACE];
/*
 * Lay out the subtree rooted at `node` into print_buffer with an in-order
 * walk. Every node and every empty child consumes one ELMENT_PRINT_SPACE-wide
 * slot starting at column print_pos_x. A node's label "element(height)" is
 * written centred in its slot on row `level`.
 *
 * Labels that would land outside print_buffer are skipped. This covers rows
 * at or beyond ELMENT_PRINT_MAX_DEPTH and columns past the row width. The slot
 * is still consumed, so the rest of the layout does not shift.
 */
void print_core_tree(avl_node_t* node, int level)
{
    if (node == NULL) {
        print_pos_x += ELMENT_PRINT_SPACE;
        return;
    }
    print_core_tree(node->left, level + 1);

    /* Large enough for "%d(%d)" with any two ints, unlike the old char[12]. */
    char str_element[32];
    int len = snprintf(str_element, sizeof str_element, "%d(%d)",
                       node->element, node->height);
    if (len > (int)sizeof str_element - 1) {
        len = (int)sizeof str_element - 1;
    }
    int start = print_pos_x + ELMENT_PRINT_OFFSET(ELMENT_PRINT_SPACE, len);
    if (len > 0 && level >= 0 && level < ELMENT_PRINT_MAX_DEPTH &&
        start >= 0 && start + len <= ELMENT_PRINT_MAX_DEPTH_SPACE) {
        memcpy(&print_buffer[level][start], str_element, (size_t)len);
    }

    print_pos_x += ELMENT_PRINT_SPACE;
    print_core_tree(node->right, level + 1);
}
/*
 * Print `root` as text, one line per tree level. Each node is shown as
 * "element(height)" at its in-order horizontal position, with a line of '='
 * above and below. An empty tree prints nothing.
 *
 * Only the first ELMENT_PRINT_MAX_DEPTH levels fit in print_buffer. For
 * deeper trees a warning is printed and the extra levels are omitted, along
 * with any columns past the buffer width.
 */
void print_tree(avl_tree root)
{
    if (root == NULL) return;

    print_pos_x = ELMENT_PRINT_SPACE / 2;
    print_tree_high = get_tree_depth(root);
    if (print_tree_high > ELMENT_PRINT_MAX_DEPTH) {
        printf("warning: only support %d depth\n", ELMENT_PRINT_MAX_DEPTH);
    }

    /* 0x7f marks cells no label was written to; they print as blanks. */
    memset(print_buffer, 0x7f, sizeof print_buffer);
    print_core_tree(root, 0);

    /* Never read outside print_buffer, even if the layout ran past it. */
    int rows = print_tree_high < ELMENT_PRINT_MAX_DEPTH
                   ? print_tree_high : ELMENT_PRINT_MAX_DEPTH;
    int cols = print_pos_x < ELMENT_PRINT_MAX_DEPTH_SPACE
                   ? print_pos_x : ELMENT_PRINT_MAX_DEPTH_SPACE;

    for (int i = 0; i < cols; i++) putchar('=');
    putchar('\n');
    for (int i = 0; i < rows; i++) {
        for (int j = 0; j < cols; j++) {
            putchar(print_buffer[i][j] == 0x7f ? ' ' : print_buffer[i][j]);
        }
        printf("\n\n");
    }
    for (int i = 0; i < cols; i++) putchar('=');
    putchar('\n');
}
/* Recursively release every node of `tree` (NULL is accepted). */
static void destroy_tree(avl_tree tree)
{
    if (tree == NULL) {
        return;
    }
    destroy_tree(tree->left);
    destroy_tree(tree->right);
    free(tree);
}

/*
 * Read integers from stdin and insert each one into an AVL tree, printing
 * the tree after every insertion. Input ends at 0, at end of file, or at the
 * first token that is not an integer.
 */
int main(int argc, char* argv[])
{
    (void)argc;
    (void)argv;
    avl_tree tree = NULL;
    element_type element;

    printf("creat avl\n");
    /* The scanf result must be checked. On EOF or bad input `element` is not
     * updated, so the old loop either read an uninitialised value or spun
     * forever re-inserting the previous one. */
    while (scanf("%d", &element) == 1 && element != 0) {
        tree = insert(element, tree);
        print_tree(tree);
    }

    destroy_tree(tree);
    return 0;
}
运行结果