最近在看数据结构,参考了网上的代码,但是在调试时发现一些问题,于是自己修改了一下。AVL树大多数操作和二叉树一样,这里我只涉及到插入和删除操作。
1. 节点定义
相比较于普通二叉树的节点,多了一个节点高度,需要注意的是,对于空节点,高度默认为0,叶子节点高度为1。节点高度的设置是为了四种旋转操作。节点和树的具体实现如下(使用了类模板):
template<class T>
class AVLnode{
public:
AVLnode(const T& el){
info = el;
left = nullptr;
right = nullptr;
height = 1;
}
AVLnode<T> *left, *right;
T info; //节点储存值
int height;
};
template<class T>
class AVLTree{
public:
AVLTree(){
root = nullptr;
}
//定义左旋操作
AVLnode<T>* L_rotate(AVLnode<T> *p);
//定义右旋操作
AVLnode<T>* R_rotate(AVLnode<T> *p);
//左子树右旋,再左旋
AVLnode<T>* LR_rotate(AVLnode<T> *p);
//右子树左旋,再右旋
AVLnode<T>* RL_rotate(AVLnode<T> *p);
void insert(const T& el);
AVLnode<T>* insert_(AVLnode<T>* &p, const T& el);
AVLnode<T>* search(const T& el) const;
void remove(const T& el);
AVLnode<T>* remove_(AVLnode<T>* &p, AVLnode<T>* cur);
AVLnode<T>* maxleft(AVLnode<T>* p);
AVLnode<T>* minright(AVLnode<T>* p);
//获取某个节点的高度
int GetHeight(AVLnode<T> *p) const{
if(p == nullptr) return 0;
return p->height;
}
AVLnode<T>* root;
};
2. 旋转操作
四种旋转操作是为了平衡删除或者插入节点之后的树,具体操作可以参考这篇博文, 这里不再说明,直接上代码。
左旋
//以p节点为根左旋
template<class T>
AVLnode<T>* AVLTree<T>::L_rotate(AVLnode<T>* p){
AVLnode<T> *l = p->left;
p->left = l->right;
l->right = p;
//更新高度
p->height = max(GetHeight(p->left), GetHeight(p->right)) + 1;
l->height = max(GetHeight(l->left), GetHeight(l->right)) + 1;
//返回新的根节点l
return l;
}
右旋
template<class T>
AVLnode<T>* AVLTree<T>::R_rotate(AVLnode<T>* p){
AVLnode<T> *l = p->right;
p->right = l->left;
l->left = p;
p->height = max(GetHeight(p->left), GetHeight(p->right)) + 1;
l->height = max(GetHeight(l->left), GetHeight(l->right)) + 1;
return l;
}
左子树右旋,再左旋
template<class T>
AVLnode<T>* AVLTree<T>::LR_rotate(AVLnode<T>* p){
AVLnode<T> *l = p->left;
p->left = R_rotate(l);
return L_rotate(p);
}
右子树左旋,再右旋
template<class T>
AVLnode<T>* AVLTree<T>::RL_rotate(AVLnode<T>* p){
AVLnode<T> *l = p->right;
p->right = L_rotate(l);
return R_rotate(p);
}
3. 插入操作
总体思想是通过递归查找到插入位置,然后回溯判断是否需要对插入后的树进行平衡(即旋转操作),详细原理可参考上述博文或者数据结构书,这里直接列出代码:
template<class T>
void AVLTree<T>::insert(const T& el){
insert_(root, el);
}
template<class T>
AVLnode<T>* AVLTree<T>::insert_(AVLnode<T>* &p, const T& el){
if(p == nullptr){
p = new AVLnode<T>(el);
if (p == nullptr)
cout<<"failed to insert a new node"<<endl;
}
else if(el < p->info){
p->left = insert_(p->left, el);
//回溯判断新树是否失衡
if(GetHeight(p->left) - GetHeight(p->right) > 1){
if(el < p->left->info) p = L_rotate(p);
else p = LR_rotate(p);
}
}
else if(el > p->info){
p->right = insert_(p->right, el);
if(GetHeight(p->right) - GetHeight(p->left) > 1){
if(el > p->right->info) p = R_rotate(p);
else p = RL_rotate(p);
}
}
else cout<<"this node is in the tree already!"<<endl;
//更新节点高度
p->height = max(GetHeight(p->left), GetHeight(p->right)) + 1;
return p;
}
4. 删除操作
AVL树的删除操作我个人理解:部分沿用了普通二叉树的复制删除操作,把需要删除的节点赋一个新的值,然后把原来树中这个值的节点删掉,所以最终实际删除的节点一定是叶子节点或者只有一个子节点的非叶子节点,需要被删除的节点只是值被替换掉了,同样需要有回溯来平衡新树,代码如下:
template<class T>
void AVLTree<T>::remove(const T& el){
//search接口与普通二叉树定义一样,返回指向该节点的指针
AVLnode<T>* cur = search(el);
if(cur != nullptr)
root = remove_(root, cur);
}
template<class T>
AVLnode<T>* AVLTree<T>::remove_(AVLnode<T>* &p, AVLnode<T>* cur){
if(cur->info < p->info){
p->left = remove_(p->left, cur);
if(GetHeight(p->right) - GetHeight(p->left) > 1){
if(GetHeight(p->right->right) > GetHeight(p->right->left))
p = R_rotate(p);
else p = RL_rotate(p);
}
}
else if(cur->info > p->info){
p->right = remove_(p->right, cur);
if(GetHeight(p->left) - GetHeight(p->right) > 1){
if(GetHeight(p->left->left) > GetHeight(p->left->right))
p = L_rotate(p);
else p = LR_rotate(p);
}
}
//找到被删除节点的位置
else{
if(p->left!=nullptr && p->right!=nullptr){
if(GetHeight(p->left) > GetHeight(p->right)){
//maxleft返回指向左子树的最大值节点,这里类似于复制删除
AVLnode<T>* tem = maxleft(p->left);
p->info = tem->info;
//转换成删除另外一个节点
remove_(p->left, tem);
}
else{
//minright返回右子树的最小值节点
AVLnode<T>* tem = minright(p->right);
p->info = tem->info;
remove_(p->right, tem);
}
}
else{
//实际的删除操作发生在这里
AVLnode<T>* tem = p;
p = (p->left == nullptr) ? p->right : p->left;
delete tem;
}
}
//更新高度,注意这里若p是空就不需要更新了,
//直接回溯到上一层递归更新高度即可
if(p != nullptr)
p->height = max(GetHeight(p->left), GetHeight(p->right)) + 1;
return p;
}
template<class T>
AVLnode<T>* AVLTree<T>::maxleft(AVLnode<T>* p){
AVLnode<T>* cur = p, *pre = p;
while(cur != nullptr){
pre = cur;
cur = cur->right;
}
return pre;
}
template<class T>
AVLnode<T>* AVLTree<T>::minright(AVLnode<T>* p){
AVLnode<T>* cur = p, *pre = p;
while(cur != nullptr){
pre = cur;
cur = cur->left;
}
return pre;
}
若有错误之处,敬请批评指正。
参考:AVL树01(c++代码实现)
AVL树(二)之 C++的实现