这两天终于把AVL树好好理解了下,在《算法分析与设计基础》这本书中,被安排在变治法章节,是实例简化思想在查找树中的应用。它对平衡的要求是:每个节点的左右子树的高度差不超过1。从而我们只要在插入或删除节点时,保证这种平衡就可以了。如果平衡被打破,使用一系列旋转使树重新达到平衡。
总共有四种类型的旋转:左单转,右单转,左右双转,右左双转。只要讲到AVL树的算法书都会有旋转的解释。两个单转之间、两个双转之间相互对称。看似复杂的双转,其思想是先转换成单转,从而通过单转实现重新平衡。所以四种旋转的实现可以非常简洁。
由于每次旋转,我们都能把树的高度恢复到插入前的水平,所以当平衡被打破时,只要一次旋转就足以解决问题。而需要旋转的是插入节点所在的,左右子树高度差大于1的最小子树。
因为AVL树是平衡查找树,先从普通的二叉查找树开始理解更容易记忆。
二叉树的插入使用递归的方式进行,可以非常清晰的表达算法思想。需要注意的是,java实现中无法使用指针,导致与c++相比,需要注意将插入的节点挂到原来的树上去。
public class BinarySortTree {
//节点结构
public static class BinaryTreeNode{
int v,height;
BinaryTreeNode leftChild,rightChild;
public BinaryTreeNode(int v){
this.v = v;
this.leftChild = null;
this.rightChild = null;
this.height = 0;
}
public BinaryTreeNode(int v, BinaryTreeNode leftChild,
BinaryTreeNode rightChild,int height) {
super();
this.v = v;
this.leftChild = leftChild;
this.rightChild = rightChild;
this.height = height;
}
}
//需要处理空节点,为空时高为-1
public static int height(BinaryTreeNode node){
return node == null ? -1:node.height;
}
private BinaryTreeNode root;
public BinaryTreeNode getRoot() {
return root;
}
public void insert(int value){
this.root = insert(value,this.root);
}
//递归插入
public BinaryTreeNode insert(int value,BinaryTreeNode t){
if(t == null){
return new BinaryTreeNode(value);
}
//插入值与当前节点比较,小于插入到左子树,大于插入到右子树
if(value
t.leftChild = insert(value,t.leftChild);
}else if(value > t.v){
t.rightChild = insert(value,t.rightChild);
}else{/*equal,do nothing*/}
//更新高度
t.height = Math.max(height(t.leftChild), height(t.rightChild)) + 1;
return t;
}
}
有了这个基础,然后再来实现平衡。所需的工作就是实现4中旋转,并且重写插入方法,在插入的时候,如果平衡被打破则进行旋转重新恢复平衡。
public class AvlTree extends BinarySortTree{
public BinaryTreeNode rotateWithLeftChild(BinaryTreeNode k2){
BinaryTreeNode k1 = k2.leftChild;
k2.leftChild= k1.rightChild;
k1.rightChild = k2;
//重新计算高度
k2.height = Math.max(height(k2.leftChild),height(k2.rightChild)) + 1;
k1.height = Math.max(height(k1.leftChild), k2.height) +1;
return k1;
}
public BinaryTreeNode rotateWithRightChild(BinaryTreeNode k2){
BinaryTreeNode k1 = k2.rightChild;
k2.rightChild= k1.leftChild;
k1.leftChild = k2;
//重新计算高度
k2.height = Math.max(height(k2.leftChild),height(k2.rightChild)) + 1;
k1.height = Math.max(height(k1.rightChild), k2.height) +1;
return k1;
}
public BinaryTreeNode doubleWithLeftChild(BinaryTreeNode k3){
k3.leftChild = rotateWithRightChild(k3.leftChild);
return rotateWithLeftChild(k3);
}
public BinaryTreeNode doubleWithRightChild(BinaryTreeNode k3){
k3.rightChild = rotateWithLeftChild(k3.rightChild);
return rotateWithRightChild(k3);
}
@Override
public BinaryTreeNode insert(int value, BinaryTreeNode t) {
if(t == null){
return new BinaryTreeNode(value);
}
//插入值与当前节点比较,小于插入到左子树,大于插入到右子树
if(value
t.leftChild = insert(value,t.leftChild);
//判断平衡是否被破坏
if(height(t.leftChild) - height(t.rightChild) == 2){
if(value
t = rotateWithLeftChild(t);
}else if(value > t.leftChild.v){
t = doubleWithLeftChild(t);
}else{/*impossible do nothing*/}
}
}else if(value > t.v){
t.rightChild = insert(value,t.rightChild);
//判断平衡是否被破坏
if(height(t.rightChild) - height(t.leftChild) == 2){
if(value > t.rightChild.v){
t = rotateWithRightChild(t);
}else if(value
t = doubleWithRightChild(t);
}else{/*impossible do nothing*/}
}
}else{/*equal,do nothing*/}
//更新高度
t.height = Math.max(height(t.leftChild), height(t.rightChild)) + 1;
return t;
}
}
使用groovy进行单元测试,同时使用中序遍历输出,更加直观:
import static org.junit.Assert.*;
import org.junit.Before;
import org.junit.Test;
import binaryTree.BinarySortTree.BinaryTreeNode;
class TestBinaryTree {
int[] date = [1,6,4,3,9,2,8,7]
@Test
public void testBinarySortTree() {
BinarySortTree tree = new BinarySortTree();
date.each{
tree.insert(it);
}
show(tree);
assertEquals tree.getRoot().rightChild.leftChild.v,4
}
@Test
public void testAvlTree(){
AvlTree tree = new AvlTree();
date.each{
tree.insert(it);
}
show(tree);
assertEquals tree.getRoot().leftChild.rightChild.v,3
tree.insert(5)
show(tree);
assertEquals tree.getRoot().rightChild.leftChild.leftChild.v,5
}
private static void show(BinarySortTree t){
print(t.getRoot(),0)
}
private static void print(BinaryTreeNode root,int depth){
if(root != null){
print(root.leftChild,depth+1);
printPrefix(depth);
System.out.println(root.v);
print(root.rightChild,depth+1);
}
}
static final String PREFIX = " ";
private static void printPrefix(int times){
while(times-->0){
if(times == 0)
System.out.print("+---");
else
System.out.print(PREFIX);
}
}
}
参考资料: