参考《数据结构与算法 Java语言描述》实现平衡二叉树(AVL 树)。
主要思想是:每个节点保存自身的高度信息,
在 remove 和 insert 的递归返回过程中检查平衡是否被破坏,若被破坏则通过单旋转或双旋转重新平衡。
实现代码如下:
/**
 * A height-balanced binary search tree (AVL tree).
 *
 * <p>Every node stores the height of the subtree rooted at it. Insertions and
 * removals are recursive; on the way back up each visited node is re-checked
 * and, if the heights of its two subtrees differ by more than one, the subtree
 * is rebalanced with a single or a double rotation.
 *
 * <p>Equal values are kept: inserting a value that compares equal to an
 * existing one stores a second copy, and {@link #remove} deletes one copy per
 * call.
 *
 * @param <T> element type; must be comparable to itself or to a supertype
 */
public class AvlTree<T extends Comparable<? super T>> {
    /** Root of the tree, or {@code null} while the tree is empty. */
    AvlNode<T> root;

    /** Creates an empty tree. */
    public AvlTree() {
        root = null;
    }

    /**
     * Inserts {@code value} into the tree, rebalancing where necessary.
     *
     * @param value the value to add
     */
    public void insert(T value) {
        root = insert(value, root);
    }

    /**
     * Inserts {@code value} into the subtree rooted at {@code node}.
     *
     * @param value the value to add
     * @param node root of the subtree, may be {@code null}
     * @return the (possibly different) root of the subtree after insertion
     */
    private AvlNode<T> insert(T value, AvlNode<T> node) {
        if (node == null) {
            return new AvlNode<>(value);
        }
        if (value.compareTo(node.value) < 0) {
            node.left = insert(value, node.left);
        } else {
            // Values that compare equal are stored in the right subtree.
            node.right = insert(value, node.right);
        }
        return balance(node);
    }

    /**
     * Removes one occurrence of {@code value}; does nothing if it is absent.
     *
     * @param value the value to remove
     */
    public void remove(T value) {
        root = remove(value, root);
    }

    /**
     * Removes one occurrence of {@code value} from the subtree rooted at
     * {@code node}.
     *
     * @param value the value to remove
     * @param node root of the subtree, may be {@code null}
     * @return the root of the subtree after removal, or {@code null} if the
     *     subtree became empty
     */
    private AvlNode<T> remove(T value, AvlNode<T> node) {
        if (node == null) {
            return null;
        }
        int result = value.compareTo(node.value);
        if (result < 0) {
            node.left = remove(value, node.left);
        } else if (result > 0) {
            node.right = remove(value, node.right);
        } else if (node.left != null) {
            // Replace with the in-order predecessor, then delete that value
            // from the left subtree.
            node.value = findMax(node.left).value;
            node.left = remove(node.value, node.left);
        } else if (node.right != null) {
            // No left subtree: use the in-order successor instead.
            node.value = findMin(node.right).value;
            node.right = remove(node.value, node.right);
        } else {
            // Leaf: simply drop it.
            return null;
        }
        return balance(node);
    }

    /**
     * Restores the AVL property at {@code node} and refreshes its height.
     *
     * <p>The rotation is chosen from the heights of the grandchild subtrees
     * rather than from the inserted value. When the two grandchild subtrees
     * have equal height (possible after a removal) a single rotation is used;
     * a double rotation in that situation could leave the result unbalanced.
     *
     * @param node root of the subtree to rebalance, may be {@code null}
     * @return the root of the rebalanced subtree
     */
    private AvlNode<T> balance(AvlNode<T> node) {
        if (node == null) {
            return null;
        }
        int diff = getHeight(node.left) - getHeight(node.right);
        if (diff > 1) {
            if (getHeight(node.left.left) >= getHeight(node.left.right)) {
                node = rotateWithLeftChild(node);
            } else {
                node = doubleWithLeftChild(node);
            }
        } else if (diff < -1) {
            if (getHeight(node.right.right) >= getHeight(node.right.left)) {
                node = rotateWithRightChild(node);
            } else {
                node = doubleWithRightChild(node);
            }
        }
        node.height = Math.max(getHeight(node.left), getHeight(node.right)) + 1;
        return node;
    }

    /**
     * Returns the smallest value in the tree.
     *
     * @return the minimum value, or {@code null} if the tree is empty
     */
    public T findMin() {
        AvlNode<T> min = findMin(root);
        return min == null ? null : min.value;
    }

    /**
     * Returns the largest value in the tree.
     *
     * @return the maximum value, or {@code null} if the tree is empty
     */
    public T findMax() {
        AvlNode<T> max = findMax(root);
        return max == null ? null : max.value;
    }

    /**
     * Finds the left-most node of a subtree.
     *
     * @param node root of the subtree, may be {@code null}
     * @return the left-most node, or {@code null} if {@code node} is {@code null}
     */
    private AvlNode<T> findMin(AvlNode<T> node) {
        if (node == null) {
            return null;
        }
        while (node.left != null) {
            node = node.left;
        }
        return node;
    }

    /**
     * Finds the right-most node of a subtree.
     *
     * @param node root of the subtree, may be {@code null}
     * @return the right-most node, or {@code null} if {@code node} is {@code null}
     */
    private AvlNode<T> findMax(AvlNode<T> node) {
        if (node == null) {
            return null;
        }
        while (node.right != null) {
            node = node.right;
        }
        return node;
    }

    /** Prints every value in ascending order to standard output, one per line. */
    public void printTree() {
        printTree(root);
    }

    /**
     * In-order traversal that prints the subtree rooted at {@code node}.
     *
     * @param node root of the subtree, may be {@code null}
     */
    private void printTree(AvlNode<T> node) {
        if (node != null) {
            printTree(node.left);
            System.out.println(node.value+" ");
            printTree(node.right);
        }
    }

    /**
     * Single rotation that lifts the left child above {@code node}; the
     * heights of both affected nodes are recomputed.
     *
     * @param node root of the subtree; its left child must not be {@code null}
     * @return the new root of the subtree after the rotation
     */
    private AvlNode<T> rotateWithLeftChild(AvlNode<T> node) {
        AvlNode<T> k2 = node.left;
        node.left = k2.right;
        k2.right = node;
        // node is now a child of k2, so its height has to be refreshed first.
        node.height = Math.max(getHeight(node.left), getHeight(node.right)) + 1;
        k2.height = Math.max(getHeight(k2.left), getHeight(k2.right)) + 1;
        return k2;
    }

    /**
     * Single rotation that lifts the right child above {@code node}; the
     * heights of both affected nodes are recomputed.
     *
     * @param node root of the subtree; its right child must not be {@code null}
     * @return the new root of the subtree after the rotation
     */
    private AvlNode<T> rotateWithRightChild(AvlNode<T> node) {
        AvlNode<T> k2 = node.right;
        node.right = k2.left;
        k2.left = node;
        // node is now a child of k2, so its height has to be refreshed first.
        node.height = Math.max(getHeight(node.left), getHeight(node.right)) + 1;
        k2.height = Math.max(getHeight(k2.left), getHeight(k2.right)) + 1;
        return k2;
    }

    /**
     * Double rotation for the left-right case: rotates the left child with its
     * right child, then {@code node} with its new left child.
     *
     * @param node root of the subtree
     * @return the new root of the subtree after the rotation
     */
    private AvlNode<T> doubleWithLeftChild(AvlNode<T> node) {
        node.left = rotateWithRightChild(node.left);
        return rotateWithLeftChild(node);
    }

    /**
     * Double rotation for the right-left case: rotates the right child with
     * its left child, then {@code node} with its new right child.
     *
     * @param node root of the subtree
     * @return the new root of the subtree after the rotation
     */
    private AvlNode<T> doubleWithRightChild(AvlNode<T> node) {
        node.right = rotateWithLeftChild(node.right);
        return rotateWithRightChild(node);
    }

    /**
     * Returns the stored height of {@code node}, treating an empty subtree as
     * height -1 so that a leaf has height 0.
     *
     * @param node a node, may be {@code null}
     * @return the node's height, or -1 for {@code null}
     */
    private int getHeight(AvlNode<T> node) {
        return node == null ? -1 : node.height;
    }

    /**
     * Tree node holding a value, its two children and its cached height.
     *
     * @param <T> element type
     */
    private static class AvlNode<T> {
        T value;
        AvlNode<T> left;
        AvlNode<T> right;
        /** Height of the subtree rooted at this node; a leaf has height 0. */
        int height;

        /**
         * Creates a leaf node.
         *
         * @param value the value to store
         */
        public AvlNode(T value) {
            this(value, null, null);
        }

        /**
         * Creates a node with the given children; its height starts at 0.
         *
         * @param value the value to store
         * @param left left child, may be {@code null}
         * @param right right child, may be {@code null}
         */
        public AvlNode(T value, AvlNode<T> left, AvlNode<T> right) {
            this.value = value;
            this.left = left;
            this.right = right;
            height = 0;
        }
    }
}
测试代码:
import org.junit.Test;
/**
 * Smoke test for {@code AvlTree}: fills a tree with the keys 1 to 9, prints
 * the smallest and largest key, removes three keys and prints what is left.
 * The output is meant to be inspected by hand; nothing is asserted.
 */
public class AvlTreeTest {
    @Test
    public void test() {
        AvlTree<Integer> tree = new AvlTree<>();
        // Ascending inserts force a rebalance on almost every step.
        for (int key = 1; key <= 9; key++) {
            tree.insert(key);
        }
        System.out.println(tree.findMin());
        System.out.println(tree.findMax());
        // Remove 5, 4 and 3, in that order.
        for (int key = 5; key >= 3; key--) {
            tree.remove(key);
        }
        tree.printTree();
    }
}
删除部分的代码测试还不够完善,欢迎各位批评指正!