典型的 AVL 树插入与旋转问题，思路详见博客。
重点注意四种旋转方式（LL、RR、LR、RL）下，中间节点的左右孩子应当如何变换。
#include<bits/stdc++.h>
using namespace std;
// Presumably a capacity bound carried over from a contest template.
// NOTE(review): nothing in this program references it — confirm before removing.
constexpr int maxn = 25;
// Node of the binary search tree that Insert() keeps balanced as an AVL tree.
// Every member is initialized by the constructor, so a freshly allocated
// node is a valid leaf even before the caller assigns its key.
struct node
{
    node * l;   // left subtree; Insert() sends keys smaller than `value` here
    node * r;   // right subtree; Insert() sends keys >= `value` here
    int value;  // key stored in this node (0 until the caller sets it)
    node() : l(nullptr), r(nullptr), value(0)
    {
    }
};
// Height of the subtree rooted at `root`, counted in nodes: an empty subtree
// (nullptr) has height 0 and a single leaf has height 1.
// Recomputed recursively on every call; nothing is cached in the nodes.
int get_height(node *root)
{
    if(!root)
        return 0;
    const int left_height = get_height(root -> l);
    const int right_height = get_height(root -> r);
    return 1 + (left_height < right_height ? right_height : left_height);
}
// Single right rotation, used by Insert() for the "LL" case (new key went
// into the left child's left subtree).
// The left child becomes the new subtree root, `root` becomes its right
// child, and the left child's former right subtree is re-hung as root's
// left subtree, which keeps the in-order sequence unchanged.
// Precondition: root and root->l are non-null. Returns the new subtree root.
node* LL(node *root)
{
    node *pivot = root -> l;
    node *moved = pivot -> r;
    pivot -> r = root;
    root -> l = moved;
    return pivot;
}
// Single left rotation, used by Insert() for the "RR" case (new key went
// into the right child's right subtree).
// The right child becomes the new subtree root, `root` becomes its left
// child, and the right child's former left subtree is re-hung as root's
// right subtree, which keeps the in-order sequence unchanged.
// Precondition: root and root->r are non-null. Returns the new subtree root.
node* RR(node *root)
{
    node *pivot = root -> r;
    node *moved = pivot -> l;
    pivot -> l = root;
    root -> r = moved;
    return pivot;
}
// Double rotation for the "LR" case (new key went into the left child's
// right subtree): first rotate the left child to the left (RR), which turns
// the shape into an LL case, then rotate `root` to the right (LL).
// Returns the new subtree root.
node* LR(node *root)
{
    node *straightened = RR(root -> l);
    root -> l = straightened;
    return LL(root);
}
// Double rotation for the "RL" case (new key went into the right child's
// left subtree): first rotate the right child to the right (LL), which turns
// the shape into an RR case, then rotate `root` to the left (RR).
// Returns the new subtree root.
node *RL(node *root)
{
    node *straightened = LL(root -> r);
    root -> r = straightened;
    return RR(root);
}
// Inserts `value` into the AVL subtree rooted at `root` and returns the
// subtree's (possibly new) root. Keys smaller than a node go left; equal or
// larger keys go right, so duplicates are kept.
// After the recursive insert, if the side that grew is now exactly two levels
// taller than the other side, the matching rotation restores balance:
//   left-heavy:  new key < left child's key  -> LL, otherwise LR
//   right-heavy: new key < right child's key -> RL, otherwise RR
node* Insert(node *root,int value)
{
    if(!root)
    {
        node *leaf = new node;
        leaf -> value = value;
        return leaf;
    }
    const bool go_left = value < root -> value;
    if(go_left)
    {
        root -> l = Insert(root -> l,value);
        const bool unbalanced = get_height(root -> l) - get_height(root -> r) == 2;
        if(unbalanced)
            return value < root -> l -> value ? LL(root) : LR(root);
        return root;
    }
    root -> r = Insert(root -> r,value);
    const bool unbalanced = get_height(root -> r) - get_height(root -> l) == 2;
    if(unbalanced)
        return value < root -> r -> value ? RL(root) : RR(root);
    return root;
}
// Reads N, then N integers, from stdin; inserts them in order into an AVL
// tree and prints the key stored at the final root.
// Input problems are handled explicitly:
// - If N cannot be read, the program exits with status 1.
// - If the values run out early, only the keys actually read are inserted.
// - If no key was inserted (e.g. N == 0), root stays nullptr, so nothing is
//   printed instead of dereferencing a null root.
int main()
{
    int n = 0;
    if(scanf("%d",&n) != 1)
        return 1;
    node *root = nullptr;
    for(int i = 1;i <= n;++i)
    {
        int a = 0;
        if(scanf("%d",&a) != 1)
            break;
        root = Insert(root,a);
    }
    if(root == nullptr)
        return 0;
    printf("%d\n",root -> value);
    return 0;
}
#include<bits/stdc++.h>
using namespace std;
// Presumably a capacity bound carried over from a contest template.
// NOTE(review): nothing in this program references it — confirm before removing.
constexpr int maxn = 25;
// Node of the binary search tree that Insert()/Fix() keep balanced as an
// AVL tree. Every member is initialized by the constructor, so a freshly
// allocated node is a valid leaf even before the caller assigns its key.
struct node
{
    node * l;   // left subtree; Insert() sends keys smaller than `value` here
    node * r;   // right subtree; Insert() sends keys >= `value` here
    int value;  // key stored in this node (0 until the caller sets it)
    node() : l(nullptr), r(nullptr), value(0)
    {
    }
};
// Number of nodes on the longest downward path starting at `root`
// (0 for an empty subtree). Walks the whole subtree on every call.
int get_height(node *root)
{
    return root == nullptr
        ? 0
        : 1 + std::max(get_height(root -> l), get_height(root -> r));
}
// Right rotation around `root` (fixes a left-left imbalance).
// Before:      root            After:    child
//             /                         /     \
//          child                      ...     root
//              \                              /
//             inner                        inner
// Precondition: root and root->l are non-null. Returns the new subtree root.
node* LL(node *root)
{
    node *child = root -> l;
    node *inner = child -> r;
    child -> r = root;
    root -> l = inner;
    return child;
}
// Left rotation around `root` (fixes a right-right imbalance).
// Before:   root                  After:      child
//               \                            /     \
//               child                     root     ...
//              /                              \
//           inner                            inner
// Precondition: root and root->r are non-null. Returns the new subtree root.
node* RR(node *root)
{
    node *child = root -> r;
    node *inner = child -> l;
    child -> l = root;
    root -> r = inner;
    return child;
}
// Left-right double rotation: rotate the left child left (RR) so the
// imbalance becomes left-left, then rotate `root` right (LL).
// Returns the new subtree root.
node* LR(node *root)
{
    node *fixed_left = RR(root -> l);
    root -> l = fixed_left;
    return LL(root);
}
// Right-left double rotation: rotate the right child right (LL) so the
// imbalance becomes right-right, then rotate `root` left (RR).
// Returns the new subtree root.
node *RL(node *root)
{
    node *fixed_right = LL(root -> r);
    root -> r = fixed_right;
    return RR(root);
}
// Rebalances the subtree rooted at `root` when one side is at least two
// levels taller than the other, and returns the new subtree root.
// The taller side picks the rotation family. Within the heavy child, the
// taller grandchild decides between a single and a double rotation:
//   left-heavy:  height(l->l) >= height(l->r) -> LL, otherwise LR
//   right-heavy: height(r->r) >= height(r->l) -> RR, otherwise RL
// Ties pick the single rotation on BOTH sides. Right after an insertion a tie
// cannot occur, but after a deletion it can, and there only the single
// rotation restores balance.
// A subtree that is already balanced (heights differ by at most 1), or an
// empty one, is returned unchanged rather than rotated around a child that
// may be null.
node* Fix(node* root)
{
    if(root == nullptr)
        return root;
    const int lheight = get_height(root -> l);
    const int rheight = get_height(root -> r);
    if(lheight - rheight >= 2)
    {
        // Left-heavy: root->l is non-null because lheight >= 2.
        if(get_height(root -> l -> l) >= get_height(root -> l -> r))
            return LL(root);
        return LR(root);
    }
    if(rheight - lheight >= 2)
    {
        // Right-heavy: root->r is non-null because rheight >= 2.
        if(get_height(root -> r -> r) >= get_height(root -> r -> l))
            return RR(root);
        return RL(root);
    }
    return root;
}
// Inserts `value` into the AVL subtree rooted at `root` and returns the
// subtree's (possibly new) root. Smaller keys descend left; equal or larger
// keys descend right. When the child that received the key ends up exactly
// two levels taller than its sibling, Fix() performs the rebalancing rotation.
node* Insert(node *root,int value)
{
    if(root == nullptr)
    {
        node *created = new node;
        created -> value = value;
        return created;
    }
    const bool descend_left = value < root -> value;
    // `grown` aliases the child pointer that receives the new key;
    // `sibling` is the other child, whose pointer the recursion never touches.
    node *&grown = descend_left ? root -> l : root -> r;
    node *sibling = descend_left ? root -> r : root -> l;
    grown = Insert(grown,value);
    if(get_height(grown) == get_height(sibling) + 2)
        return Fix(root);
    return root;
}
// Reads N, then N integers, from stdin; builds an AVL tree by inserting
// them in order and prints the key at the final root.
// - If N is unreadable, the program exits with status 1.
// - If the values are truncated, the keys read so far are kept.
// - An empty tree (e.g. N == 0) prints nothing instead of dereferencing the
//   null root.
int main()
{
    int n = 0;
    if(scanf("%d",&n) != 1)
        return 1;
    node *root = nullptr;
    for(int i = 1;i <= n;++i)
    {
        int a = 0;
        if(scanf("%d",&a) != 1)
            break;
        root = Insert(root,a);
    }
    if(root == nullptr)
        return 0;
    printf("%d\n",root -> value);
    return 0;
}